Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
6f3250cb
Commit
6f3250cb
authored
Jun 30, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Improved docstrings in torch_harmonics/examples
parent
50ebe96f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
262 additions
and
50 deletions
+262
-50
torch_harmonics/examples/losses.py
torch_harmonics/examples/losses.py
+13
-4
torch_harmonics/examples/models/_layers.py
torch_harmonics/examples/models/_layers.py
+45
-23
torch_harmonics/examples/models/lsno.py
torch_harmonics/examples/models/lsno.py
+48
-8
torch_harmonics/examples/models/sfno.py
torch_harmonics/examples/models/sfno.py
+42
-8
torch_harmonics/examples/pde_dataset.py
torch_harmonics/examples/pde_dataset.py
+32
-1
torch_harmonics/examples/shallow_water_equations.py
torch_harmonics/examples/shallow_water_equations.py
+38
-3
torch_harmonics/examples/stanford_2d3ds_dataset.py
torch_harmonics/examples/stanford_2d3ds_dataset.py
+44
-3
No files found.
torch_harmonics/examples/losses.py
View file @
6f3250cb
...
@@ -492,10 +492,19 @@ class NormalLossS2(SphericalLossBase):
...
@@ -492,10 +492,19 @@ class NormalLossS2(SphericalLossBase):
Surface normals are computed by calculating gradients in latitude and longitude
Surface normals are computed by calculating gradients in latitude and longitude
directions using FFT, then constructing 3D normal vectors that are normalized.
directions using FFT, then constructing 3D normal vectors that are normalized.
Args:
Parameters
nlat (int): Number of latitude points
----------
nlon (int): Number of longitude points
nlat : int
grid (str, optional): Grid type. Defaults to "equiangular".
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
Returns
-------
torch.Tensor
Combined loss term
"""
"""
def
__init__
(
self
,
nlat
:
int
,
nlon
:
int
,
grid
:
str
=
"equiangular"
):
def
__init__
(
self
,
nlat
:
int
,
nlon
:
int
,
grid
:
str
=
"equiangular"
):
...
...
torch_harmonics/examples/models/_layers.py
View file @
6f3250cb
...
@@ -118,13 +118,21 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
...
@@ -118,13 +118,21 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
with values outside :math:`[a, b]` redrawn until they are within
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
Parameters
mean: the mean of the normal distribution
-----------
std: the standard deviation of the normal distribution
tensor: torch.Tensor
a: the minimum cutoff value
an n-dimensional `torch.Tensor`
b: the maximum cutoff value
mean: float
Examples:
the mean of the normal distribution
std: float
the standard deviation of the normal distribution
a: float
the minimum cutoff value, by default -2.0
b: float
the maximum cutoff value
Examples
--------
>>> w = torch.empty(3, 5)
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
>>> nn.init.trunc_normal_(w)
"""
"""
...
@@ -139,6 +147,20 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
...
@@ -139,6 +147,20 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
'survival rate' as the argument.
Parameters
----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Probability of dropping a path, by default 0.0
training : bool, optional
Whether the model is in training mode, by default False
Returns
-------
torch.Tensor
Output tensor
"""
"""
if
drop_prob
==
0.0
or
not
training
:
if
drop_prob
==
0.0
or
not
training
:
return
x
return
x
...
@@ -159,7 +181,7 @@ class DropPath(nn.Module):
...
@@ -159,7 +181,7 @@ class DropPath(nn.Module):
training of very deep networks.
training of very deep networks.
Parameters
Parameters
----------
-
----------
drop_prob : float, optional
drop_prob : float, optional
Probability of dropping a path, by default None
Probability of dropping a path, by default None
"""
"""
...
@@ -173,7 +195,7 @@ class DropPath(nn.Module):
...
@@ -173,7 +195,7 @@ class DropPath(nn.Module):
Forward pass with drop path regularization.
Forward pass with drop path regularization.
Parameters
Parameters
-
----------
----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
...
@@ -193,7 +215,7 @@ class PatchEmbed(nn.Module):
...
@@ -193,7 +215,7 @@ class PatchEmbed(nn.Module):
higher dimensional embedding space using convolutional layers.
higher dimensional embedding space using convolutional layers.
Parameters
Parameters
----------
-
----------
img_size : tuple, optional
img_size : tuple, optional
Input image size (height, width), by default (224, 224)
Input image size (height, width), by default (224, 224)
patch_size : tuple, optional
patch_size : tuple, optional
...
@@ -220,7 +242,7 @@ class PatchEmbed(nn.Module):
...
@@ -220,7 +242,7 @@ class PatchEmbed(nn.Module):
Forward pass of patch embedding.
Forward pass of patch embedding.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor of shape (batch_size, channels, height, width)
Input tensor of shape (batch_size, channels, height, width)
...
@@ -245,7 +267,7 @@ class MLP(nn.Module):
...
@@ -245,7 +267,7 @@ class MLP(nn.Module):
and an activation function, with optional dropout and gradient checkpointing.
and an activation function, with optional dropout and gradient checkpointing.
Parameters
Parameters
----------
-
----------
in_features : int
in_features : int
Number of input features
Number of input features
hidden_features : int, optional
hidden_features : int, optional
...
@@ -301,7 +323,7 @@ class MLP(nn.Module):
...
@@ -301,7 +323,7 @@ class MLP(nn.Module):
Forward pass with gradient checkpointing.
Forward pass with gradient checkpointing.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
...
@@ -317,7 +339,7 @@ class MLP(nn.Module):
...
@@ -317,7 +339,7 @@ class MLP(nn.Module):
Forward pass of the MLP.
Forward pass of the MLP.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
...
@@ -364,7 +386,7 @@ class RealFFT2(nn.Module):
...
@@ -364,7 +386,7 @@ class RealFFT2(nn.Module):
Forward pass: compute real FFT2D.
Forward pass: compute real FFT2D.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
...
@@ -410,7 +432,7 @@ class InverseRealFFT2(nn.Module):
...
@@ -410,7 +432,7 @@ class InverseRealFFT2(nn.Module):
Forward pass: compute inverse real FFT2D.
Forward pass: compute inverse real FFT2D.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input FFT coefficients
Input FFT coefficients
...
@@ -431,7 +453,7 @@ class LayerNorm(nn.Module):
...
@@ -431,7 +453,7 @@ class LayerNorm(nn.Module):
applying normalization, and then transposing back.
applying normalization, and then transposing back.
Parameters
Parameters
----------
-
----------
in_channels : int
in_channels : int
Number of input channels
Number of input channels
eps : float, optional
eps : float, optional
...
@@ -458,7 +480,7 @@ class LayerNorm(nn.Module):
...
@@ -458,7 +480,7 @@ class LayerNorm(nn.Module):
Forward pass with channel dimension handling.
Forward pass with channel dimension handling.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor with channel dimension at -3
Input tensor with channel dimension at -3
...
@@ -477,7 +499,7 @@ class SpectralConvS2(nn.Module):
...
@@ -477,7 +499,7 @@ class SpectralConvS2(nn.Module):
domain via the RealFFT2 and InverseRealFFT2 wrappers.
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
Parameters
----------
-
----------
forward_transform : nn.Module
forward_transform : nn.Module
Forward transform (SHT or FFT)
Forward transform (SHT or FFT)
inverse_transform : nn.Module
inverse_transform : nn.Module
...
@@ -538,7 +560,7 @@ class SpectralConvS2(nn.Module):
...
@@ -538,7 +560,7 @@ class SpectralConvS2(nn.Module):
Forward pass of spectral convolution.
Forward pass of spectral convolution.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
...
@@ -576,7 +598,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
...
@@ -576,7 +598,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
that add positional information to input tensors.
that add positional information to input tensors.
Parameters
Parameters
----------
-
----------
img_shape : tuple, optional
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
Image shape (height, width), by default (480, 960)
grid : str, optional
grid : str, optional
...
@@ -616,7 +638,7 @@ class SequencePositionEmbedding(PositionEmbedding):
...
@@ -616,7 +638,7 @@ class SequencePositionEmbedding(PositionEmbedding):
used in the original Transformer paper, adapted for 2D spatial data.
used in the original Transformer paper, adapted for 2D spatial data.
Parameters
Parameters
----------
-
----------
img_shape : tuple, optional
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
Image shape (height, width), by default (480, 960)
grid : str, optional
grid : str, optional
...
@@ -696,7 +718,7 @@ class LearnablePositionEmbedding(PositionEmbedding):
...
@@ -696,7 +718,7 @@ class LearnablePositionEmbedding(PositionEmbedding):
latitude-only or full latitude-longitude embeddings.
latitude-only or full latitude-longitude embeddings.
Parameters
Parameters
----------
-
----------
img_shape : tuple, optional
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
Image shape (height, width), by default (480, 960)
grid : str, optional
grid : str, optional
...
...
torch_harmonics/examples/models/lsno.py
View file @
6f3250cb
...
@@ -57,7 +57,7 @@ class DiscreteContinuousEncoder(nn.Module):
...
@@ -57,7 +57,7 @@ class DiscreteContinuousEncoder(nn.Module):
reducing the spatial resolution while maintaining the spectral properties of the data.
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
Parameters
----------
-
----------
in_shape : tuple, optional
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
out_shape : tuple, optional
...
@@ -114,7 +114,7 @@ class DiscreteContinuousEncoder(nn.Module):
...
@@ -114,7 +114,7 @@ class DiscreteContinuousEncoder(nn.Module):
Forward pass of the discrete-continuous encoder.
Forward pass of the discrete-continuous encoder.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Input tensor with shape (batch, channels, nlat, nlon)
...
@@ -141,7 +141,7 @@ class DiscreteContinuousDecoder(nn.Module):
...
@@ -141,7 +141,7 @@ class DiscreteContinuousDecoder(nn.Module):
followed by discrete-continuous convolutions to restore spatial resolution.
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
Parameters
----------
-
----------
in_shape : tuple, optional
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
out_shape : tuple, optional
...
@@ -209,7 +209,7 @@ class DiscreteContinuousDecoder(nn.Module):
...
@@ -209,7 +209,7 @@ class DiscreteContinuousDecoder(nn.Module):
Forward pass of the discrete-continuous decoder.
Forward pass of the discrete-continuous decoder.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Input tensor with shape (batch, channels, nlat, nlon)
...
@@ -232,6 +232,46 @@ class DiscreteContinuousDecoder(nn.Module):
...
@@ -232,6 +232,46 @@ class DiscreteContinuousDecoder(nn.Module):
class
SphericalNeuralOperatorBlock
(
nn
.
Module
):
class
SphericalNeuralOperatorBlock
(
nn
.
Module
):
"""
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
conv_type : str, optional
Type of convolution to use, by default "local"
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
disco_kernel_shape : tuple, optional
Kernel shape for discrete-continuous convolution, by default (3, 3)
disco_basis_type : str, optional
Filter basis type for discrete-continuous convolution, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
"""
def
__init__
(
def
__init__
(
...
@@ -367,7 +407,7 @@ class LocalSphericalNeuralOperator(nn.Module):
...
@@ -367,7 +407,7 @@ class LocalSphericalNeuralOperator(nn.Module):
as well as in the encoder and decoders.
as well as in the encoder and decoders.
Parameters
Parameters
----------
-
----------
img_size : tuple, optional
img_size : tuple, optional
Input image size (nlat, nlon), by default (128, 256)
Input image size (nlat, nlon), by default (128, 256)
grid : str, optional
grid : str, optional
...
@@ -416,7 +456,7 @@ class LocalSphericalNeuralOperator(nn.Module):
...
@@ -416,7 +456,7 @@ class LocalSphericalNeuralOperator(nn.Module):
Whether to use a bias, by default False
Whether to use a bias, by default False
Example
Example
----------
-
----------
>>> model = LocalSphericalNeuralOperator(
>>> model = LocalSphericalNeuralOperator(
... img_shape=(128, 256),
... img_shape=(128, 256),
... scale_factor=4,
... scale_factor=4,
...
@@ -429,7 +469,7 @@ class LocalSphericalNeuralOperator(nn.Module):
...
@@ -429,7 +469,7 @@ class LocalSphericalNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
torch.Size([1, 2, 128, 256])
References
References
----------
-
----------
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024).
"Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845.
ICML 2024, https://arxiv.org/pdf/2402.16845.
...
@@ -592,7 +632,7 @@ class LocalSphericalNeuralOperator(nn.Module):
...
@@ -592,7 +632,7 @@ class LocalSphericalNeuralOperator(nn.Module):
Forward pass through the complete LSNO model.
Forward pass through the complete LSNO model.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Input tensor of shape (batch_size, in_chans, height, width)
...
...
torch_harmonics/examples/models/sfno.py
View file @
6f3250cb
...
@@ -43,6 +43,40 @@ from functools import partial
...
@@ -43,6 +43,40 @@ from functools import partial
class
SphericalFourierNeuralOperatorBlock
(
nn
.
Module
):
class
SphericalFourierNeuralOperatorBlock
(
nn
.
Module
):
"""
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
"""
def
__init__
(
def
__init__
(
...
@@ -123,12 +157,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
...
@@ -123,12 +157,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
Forward pass through the SFNO block.
Forward pass through the SFNO block.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
Returns
Returns
-------
-------
---
torch.Tensor
torch.Tensor
Output tensor after processing through the block
Output tensor after processing through the block
"""
"""
...
@@ -198,7 +232,7 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -198,7 +232,7 @@ class SphericalFourierNeuralOperator(nn.Module):
Whether to use a bias, by default False
Whether to use a bias, by default False
Example:
Example:
--------
--------
--
>>> model = SphericalFourierNeuralOperator(
>>> model = SphericalFourierNeuralOperator(
... img_size=(128, 256),
... img_size=(128, 256),
... scale_factor=4,
... scale_factor=4,
...
@@ -211,7 +245,7 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -211,7 +245,7 @@ class SphericalFourierNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
torch.Size([1, 2, 128, 256])
References
References
----------
-
----------
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
ICML 2023, https://arxiv.org/abs/2306.03838.
...
@@ -385,12 +419,12 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -385,12 +419,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Forward pass through the feature extraction layers.
Forward pass through the feature extraction layers.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor
Input tensor
Returns
Returns
-------
-------
---
torch.Tensor
torch.Tensor
Features after processing through the network
Features after processing through the network
"""
"""
...
@@ -406,12 +440,12 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -406,12 +440,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Forward pass through the complete SFNO model.
Forward pass through the complete SFNO model.
Parameters
Parameters
----------
-
----------
x : torch.Tensor
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Input tensor of shape (batch_size, in_chans, height, width)
Returns
Returns
-------
-------
---
torch.Tensor
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
Output tensor of shape (batch_size, out_chans, height, width)
"""
"""
...
...
torch_harmonics/examples/pde_dataset.py
View file @
6f3250cb
...
@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver
...
@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver
class
PdeDataset
(
torch
.
utils
.
data
.
Dataset
):
class
PdeDataset
(
torch
.
utils
.
data
.
Dataset
):
"""Custom Dataset class for PDE training data"""
"""Custom Dataset class for PDE training data
Parameters
----------
dt : float
Time step
nsteps : int
Number of solver steps
dims : tuple, optional
Number of latitude and longitude points, by default (384, 768)
grid : str, optional
Grid type, by default "equiangular"
pde : str, optional
PDE type, by default "shallow water equations"
initial_condition : str, optional
Initial condition type, by default "random"
num_examples : int, optional
Number of examples, by default 32
device : torch.device, optional
Device to use, by default torch.device("cpu")
normalize : bool, optional
Whether to normalize the input and target, by default True
stream : torch.cuda.Stream, optional
CUDA stream to use, by default None
Returns
-------
inp : torch.Tensor
Input tensor
tar : torch.Tensor
Target tensor
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
...
torch_harmonics/examples/shallow_water_equations.py
View file @
6f3250cb
...
@@ -302,13 +302,19 @@ class ShallowWaterSolver(nn.Module):
...
@@ -302,13 +302,19 @@ class ShallowWaterSolver(nn.Module):
"""
"""
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
Parameters
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
----------
None
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
Initial spectral coefficients for the Galewsky test case
Initial spectral coefficients for the Galewsky test case
References
----------
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
"""
device
=
self
.
lap
.
device
device
=
self
.
lap
.
device
...
@@ -407,6 +413,18 @@ class ShallowWaterSolver(nn.Module):
...
@@ -407,6 +413,18 @@ class ShallowWaterSolver(nn.Module):
def
timestep
(
self
,
uspec
:
torch
.
Tensor
,
nsteps
:
int
)
->
torch
.
Tensor
:
def
timestep
(
self
,
uspec
:
torch
.
Tensor
,
nsteps
:
int
)
->
torch
.
Tensor
:
"""
"""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
Parameters
----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
nsteps : int
Number of time steps to integrate
Returns
-------
torch.Tensor
Integrated spectral coefficients
"""
"""
dudtspec
=
torch
.
zeros
(
3
,
3
,
self
.
lmax
,
self
.
mmax
,
dtype
=
uspec
.
dtype
,
device
=
uspec
.
device
)
dudtspec
=
torch
.
zeros
(
3
,
3
,
self
.
lmax
,
self
.
mmax
,
dtype
=
uspec
.
dtype
,
device
=
uspec
.
device
)
...
@@ -440,6 +458,23 @@ class ShallowWaterSolver(nn.Module):
...
@@ -440,6 +458,23 @@ class ShallowWaterSolver(nn.Module):
return
uspec
return
uspec
def
integrate_grid
(
self
,
ugrid
,
dimensionless
=
False
,
polar_opt
=
0
):
def
integrate_grid
(
self
,
ugrid
,
dimensionless
=
False
,
polar_opt
=
0
):
"""
Integrate the solution on the grid.
Parameters
----------
ugrid : torch.Tensor
Grid data
dimensionless : bool, optional
Whether to use dimensionless units, by default False
polar_opt : int, optional
Number of polar points to exclude, by default 0
Returns
-------
torch.Tensor
Integrated grid data
"""
dlon
=
2
*
torch
.
pi
/
self
.
nlon
dlon
=
2
*
torch
.
pi
/
self
.
nlon
radius
=
1
if
dimensionless
else
self
.
radius
radius
=
1
if
dimensionless
else
self
.
radius
if
polar_opt
>
0
:
if
polar_opt
>
0
:
...
...
torch_harmonics/examples/stanford_2d3ds_dataset.py
View file @
6f3250cb
...
@@ -58,8 +58,22 @@ class Stanford2D3DSDownloader:
...
@@ -58,8 +58,22 @@ class Stanford2D3DSDownloader:
"""
"""
Convenience class for downloading the 2d3ds dataset [1].
Convenience class for downloading the 2d3ds dataset [1].
Parameters
----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
Returns
-------
data_folders : list
List of extracted directory names
class_labels : list
List of semantic class labels
References
References
----------
-
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
https://arxiv.org/abs/1702.01105.
...
@@ -483,8 +497,24 @@ class StanfordSegmentationDataset(Dataset):
...
@@ -483,8 +497,24 @@ class StanfordSegmentationDataset(Dataset):
"""
"""
Spherical segmentation dataset from [1].
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
Returns
-------
StanfordSegmentationDataset
Dataset object
References
References
----------
-
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
https://arxiv.org/abs/1702.01105.
...
@@ -620,8 +650,19 @@ class StanfordDepthDataset(Dataset):
...
@@ -620,8 +650,19 @@ class StanfordDepthDataset(Dataset):
"""
"""
Spherical segmentation dataset from [1].
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
References
References
----------
-
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
https://arxiv.org/abs/1702.01105.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment