OpenDAS / torch-harmonics / Commits / 6f3250cb

Improved docstrings in torch_harmonics/examples

Authored Jun 30, 2025 by apaaris; committed by Boris Bonev, Jul 21, 2025
Parent: 50ebe96f

Showing 7 changed files with 262 additions and 50 deletions (+262, -50)
torch_harmonics/examples/losses.py                    +13  -4
torch_harmonics/examples/models/_layers.py            +45  -23
torch_harmonics/examples/models/lsno.py               +48  -8
torch_harmonics/examples/models/sfno.py               +42  -8
torch_harmonics/examples/pde_dataset.py               +32  -1
torch_harmonics/examples/shallow_water_equations.py   +38  -3
torch_harmonics/examples/stanford_2d3ds_dataset.py    +44  -3
torch_harmonics/examples/losses.py
...
...
@@ -492,10 +492,19 @@ class NormalLossS2(SphericalLossBase):
Surface normals are computed by calculating gradients in latitude and longitude
directions using FFT, then constructing 3D normal vectors that are normalized.
Args:
nlat (int): Number of latitude points
nlon (int): Number of longitude points
grid (str, optional): Grid type. Defaults to "equiangular".
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
Returns
-------
torch.Tensor
Combined loss term
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
...
...
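For orientation, a minimal usage sketch of the constructor documented above; only __init__ appears in this hunk, so calling the loss with prediction and target grids (like the other spherical losses) is an assumption rather than something shown in the diff:

>>> import torch
>>> from torch_harmonics.examples.losses import NormalLossS2
>>> loss_fn = NormalLossS2(nlat=64, nlon=128, grid="equiangular")
>>> prd, tar = torch.randn(1, 3, 64, 128), torch.randn(1, 3, 64, 128)
>>> loss = loss_fn(prd, tar)  # call convention assumed, not shown in this hunk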
torch_harmonics/examples/models/_layers.py
...
...
@@ -118,13 +118,21 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
Parameters
-----------
tensor: torch.Tensor
an n-dimensional `torch.Tensor`
mean: float
the mean of the normal distribution
std: float
the standard deviation of the normal distribution
a: float
the minimum cutoff value, by default -2.0
b: float
the maximum cutoff value
Examples
--------
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
...
...
@@ -139,6 +147,20 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Parameters
----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Probability of dropping a path, by default 0.0
training : bool, optional
Whether the model is in training mode, by default False
Returns
-------
torch.Tensor
Output tensor
"""
if drop_prob == 0.0 or not training:
    return x
...
...
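A small doctest-style sketch of the behaviour described above, in particular the early-return branch shown in this hunk; the import path simply follows the file being edited:

>>> import torch
>>> from torch_harmonics.examples.models._layers import drop_path
>>> x = torch.randn(2, 8, 4, 4)
>>> torch.equal(drop_path(x, drop_prob=0.1, training=False), x)  # no-op outside training
True
>>> out = drop_path(x, drop_prob=0.1, training=True)  # paths dropped stochastically during training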
@@ -159,7 +181,7 @@ class DropPath(nn.Module):
training of very deep networks.
Parameters
----------
drop_prob : float, optional
Probability of dropping a path, by default None
"""
...
...
@@ -173,7 +195,7 @@ class DropPath(nn.Module):
Forward pass with drop path regularization.
Parameters
----------
x : torch.Tensor
Input tensor
...
...
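The DropPath module documented in the two hunks above wraps the drop_path helper; a hedged sketch of its passthrough behaviour in evaluation mode (the import path again follows the file being edited):

>>> import torch
>>> from torch_harmonics.examples.models._layers import DropPath
>>> dp = DropPath(drop_prob=0.1).eval()
>>> x = torch.randn(2, 8, 4, 4)
>>> torch.equal(dp(x), x)  # passthrough outside training
True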
@@ -193,7 +215,7 @@ class PatchEmbed(nn.Module):
higher dimensional embedding space using convolutional layers.
Parameters
----------
img_size : tuple, optional
Input image size (height, width), by default (224, 224)
patch_size : tuple, optional
...
...
@@ -220,7 +242,7 @@ class PatchEmbed(nn.Module):
Forward pass of patch embedding.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, channels, height, width)
...
...
@@ -245,7 +267,7 @@ class MLP(nn.Module):
and an activation function, with optional dropout and gradient checkpointing.
Parameters
----------
in_features : int
Number of input features
hidden_features : int, optional
...
...
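As a point of reference, a hedged construction-only sketch of the MLP documented above; only the leading parameters are visible in this hunk and the expected input layout is not shown, so the snippet deliberately stops at construction:

>>> from torch_harmonics.examples.models._layers import MLP
>>> mlp = MLP(in_features=64, hidden_features=128)  # remaining arguments left at their defaults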
@@ -301,7 +323,7 @@ class MLP(nn.Module):
Forward pass with gradient checkpointing.
Parameters
----------
x : torch.Tensor
Input tensor
...
...
@@ -317,7 +339,7 @@ class MLP(nn.Module):
Forward pass of the MLP.
Parameters
----------
x : torch.Tensor
Input tensor
...
...
@@ -364,7 +386,7 @@ class RealFFT2(nn.Module):
Forward pass: compute real FFT2D.
Parameters
----------
x : torch.Tensor
Input tensor
...
...
@@ -410,7 +432,7 @@ class InverseRealFFT2(nn.Module):
Forward pass: compute inverse real FFT2D.
Parameters
----------
x : torch.Tensor
Input FFT coefficients
...
...
@@ -431,7 +453,7 @@ class LayerNorm(nn.Module):
applying normalization, and then transposing back.
Parameters
----------
in_channels : int
Number of input channels
eps : float, optional
...
...
@@ -458,7 +480,7 @@ class LayerNorm(nn.Module):
Forward pass with channel dimension handling.
Parameters
----------
x : torch.Tensor
Input tensor with channel dimension at -3
...
...
@@ -477,7 +499,7 @@ class SpectralConvS2(nn.Module):
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
----------
forward_transform : nn.Module
Forward transform (SHT or FFT)
inverse_transform : nn.Module
...
...
@@ -538,7 +560,7 @@ class SpectralConvS2(nn.Module):
Forward pass of spectral convolution.
Parameters
----------
x : torch.Tensor
Input tensor
...
...
@@ -576,7 +598,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
that add positional information to input tensors.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
...
...
@@ -616,7 +638,7 @@ class SequencePositionEmbedding(PositionEmbedding):
used in the original Transformer paper, adapted for 2D spatial data.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
...
...
@@ -696,7 +718,7 @@ class LearnablePositionEmbedding(PositionEmbedding):
latitude-only or full latitude-longitude embeddings.
Parameters
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
...
...
torch_harmonics/examples/models/lsno.py
...
...
@@ -57,7 +57,7 @@ class DiscreteContinuousEncoder(nn.Module):
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
...
...
@@ -114,7 +114,7 @@ class DiscreteContinuousEncoder(nn.Module):
Forward pass of the discrete-continuous encoder.
Parameters
----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
...
...
@@ -141,7 +141,7 @@ class DiscreteContinuousDecoder(nn.Module):
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
...
...
@@ -209,7 +209,7 @@ class DiscreteContinuousDecoder(nn.Module):
Forward pass of the discrete-continuous decoder.
Parameters
----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
...
...
@@ -232,6 +232,46 @@ class DiscreteContinuousDecoder(nn.Module):
class SphericalNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
conv_type : str, optional
Type of convolution to use, by default "local"
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
disco_kernel_shape : tuple, optional
Kernel shape for discrete-continuous convolution, by default (3, 3)
disco_basis_type : str, optional
Filter basis type for discrete-continuous convolution, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
def __init__(
...
...
@@ -367,7 +407,7 @@ class LocalSphericalNeuralOperator(nn.Module):
as well as in the encoder and decoders.
Parameters
----------
img_size : tuple, optional
Input image size (nlat, nlon), by default (128, 256)
grid : str, optional
...
...
@@ -416,7 +456,7 @@ class LocalSphericalNeuralOperator(nn.Module):
Whether to use a bias, by default False
Example
----------
>>> model = LocalSphericalNeuralOperator(
... img_shape=(128, 256),
... scale_factor=4,
...
...
@@ -429,7 +469,7 @@ class LocalSphericalNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
References
----------
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845.
...
...
@@ -592,7 +632,7 @@ class LocalSphericalNeuralOperator(nn.Module):
Forward pass through the complete LSNO model.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
...
...
torch_harmonics/examples/models/sfno.py
...
...
@@ -43,6 +43,40 @@ from functools import partial
class SphericalFourierNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
def __init__(
...
...
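A hedged construction sketch for the block documented above, pairing it with the spherical harmonic transforms provided by torch_harmonics; the positional order of the first four arguments follows the docstring's parameter list and is otherwise an assumption:

>>> from torch_harmonics import RealSHT, InverseRealSHT
>>> from torch_harmonics.examples.models.sfno import SphericalFourierNeuralOperatorBlock
>>> sht = RealSHT(32, 64, grid="equiangular")
>>> isht = InverseRealSHT(32, 64, grid="equiangular")
>>> block = SphericalFourierNeuralOperatorBlock(sht, isht, 16, 16, mlp_ratio=2.0)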
@@ -123,12 +157,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
Forward pass through the SFNO block.
Parameters
----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after processing through the block
"""
...
...
@@ -198,7 +232,7 @@ class SphericalFourierNeuralOperator(nn.Module):
Whether to use a bias, by default False
Example:
--------
>>> model = SphericalFourierNeuralOperator(
... img_size=(128, 256),
... scale_factor=4,
...
...
@@ -211,7 +245,7 @@ class SphericalFourierNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
References
----------
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
...
...
@@ -385,12 +419,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Forward pass through the feature extraction layers.
Parameters
----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Features after processing through the network
"""
...
...
@@ -406,12 +440,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Forward pass through the complete SFNO model.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
...
...
torch_harmonics/examples/pde_dataset.py
...
...
@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data"""
"""Custom Dataset class for PDE training data
Parameters
----------
dt : float
Time step
nsteps : int
Number of solver steps
dims : tuple, optional
Number of latitude and longitude points, by default (384, 768)
grid : str, optional
Grid type, by default "equiangular"
pde : str, optional
PDE type, by default "shallow water equations"
initial_condition : str, optional
Initial condition type, by default "random"
num_examples : int, optional
Number of examples, by default 32
device : torch.device, optional
Device to use, by default torch.device("cpu")
normalize : bool, optional
Whether to normalize the input and target, by default True
stream : torch.cuda.Stream, optional
CUDA stream to use, by default None
Returns
-------
inp : torch.Tensor
Input tensor
tar : torch.Tensor
Target tensor
"""
def __init__(self,
...
...
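A hedged construction sketch based on the parameters documented above; the numeric values are placeholders, and indexing follows the documented (inp, tar) return:

>>> import torch
>>> from torch_harmonics.examples.pde_dataset import PdeDataset
>>> dataset = PdeDataset(dt=600, nsteps=10, dims=(128, 256), num_examples=8, device=torch.device("cpu"))
>>> inp, tar = dataset[0]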
torch_harmonics/examples/shallow_water_equations.py
...
...
@@ -302,13 +302,19 @@ class ShallowWaterSolver(nn.Module):
"""
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
Parameters
----------
None
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
References
----------
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
device = self.lap.device
...
...
@@ -407,6 +413,18 @@ class ShallowWaterSolver(nn.Module):
def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
"""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
Parameters
----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
nsteps : int
Number of time steps to integrate
Returns
-------
torch.Tensor
Integrated spectral coefficients
"""
dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)
...
...
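To illustrate the documented signature, a sketch assuming solver is an already constructed ShallowWaterSolver and uspec0 holds spectral initial conditions (for instance from the Galewsky initializer in the earlier hunk); neither construction step is shown in this diff:

>>> uspec1 = solver.timestep(uspec0, nsteps=10)   # advance by 10 solver steps
>>> uspec2 = solver.timestep(uspec1, nsteps=10)   # chain calls for longer roll-outs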
@@ -440,6 +458,23 @@ class ShallowWaterSolver(nn.Module):
return uspec
def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
"""
Integrate the solution on the grid.
Parameters
----------
ugrid : torch.Tensor
Grid data
dimensionless : bool, optional
Whether to use dimensionless units, by default False
polar_opt : int, optional
Number of polar points to exclude, by default 0
Returns
-------
torch.Tensor
Integrated grid data
"""
dlon = 2 * torch.pi / self.nlon
radius = 1 if dimensionless else self.radius
if polar_opt > 0:
...
...
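Continuing the sketch above, integrate_grid reduces a grid-space field to its area integral; obtaining ugrid from the spectral state requires the solver's inverse transform, which is not part of this hunk:

>>> total = solver.integrate_grid(ugrid, dimensionless=True, polar_opt=5)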
torch_harmonics/examples/stanford_2d3ds_dataset.py
...
...
@@ -58,8 +58,22 @@ class Stanford2D3DSDownloader:
"""
Convenience class for downloading the 2d3ds dataset [1].
Parameters
----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
Returns
-------
data_folders : list
List of extracted directory names
class_labels : list
List of semantic class labels
References
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
...
...
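A heavily hedged sketch of the downloader documented above; the constructor arguments come from the docstring, but the method that actually fetches and extracts the data is not visible in this hunk, so the commented call below uses a purely hypothetical name:

>>> from torch_harmonics.examples.stanford_2d3ds_dataset import Stanford2D3DSDownloader
>>> downloader = Stanford2D3DSDownloader(local_dir="data")
>>> # data_folders, class_labels = downloader.download()  # hypothetical method name, not shown in this diff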
@@ -483,8 +497,24 @@ class StanfordSegmentationDataset(Dataset):
"""
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
Returns
-------
StanfordSegmentationDataset
Dataset object
References
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
...
...
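A hedged usage sketch for the segmentation dataset documented above; the HDF5 path is a placeholder, and wrapping it in a DataLoader is an illustration rather than part of the diff:

>>> from torch.utils.data import DataLoader
>>> from torch_harmonics.examples.stanford_2d3ds_dataset import StanfordSegmentationDataset
>>> dataset = StanfordSegmentationDataset(dataset_file="2d3ds_sphere.h5", exclude_polar_fraction=0.05)
>>> loader = DataLoader(dataset, batch_size=4, shuffle=True)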
@@ -620,8 +650,19 @@ class StanfordDepthDataset(Dataset):
"""
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
References
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
...
...