Commit 6f3250cb authored by apaaris's avatar apaaris Committed by Boris Bonev

Improved docstrings in torch_harmonics/examples

parent 50ebe96f
@@ -492,10 +492,19 @@ class NormalLossS2(SphericalLossBase):
     Surface normals are computed by calculating gradients in latitude and longitude
     directions using FFT, then constructing 3D normal vectors that are normalized.
 
-    Args:
-        nlat (int): Number of latitude points
-        nlon (int): Number of longitude points
-        grid (str, optional): Grid type. Defaults to "equiangular".
+    Parameters
+    ----------
+    nlat : int
+        Number of latitude points
+    nlon : int
+        Number of longitude points
+    grid : str, optional
+        Grid type, by default "equiangular"
+
+    Returns
+    -------
+    torch.Tensor
+        Combined loss term
     """
 
     def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
...
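The docstring above describes building normalized 3D normal vectors from latitude and longitude gradients. As a rough illustration of that construction (a hypothetical `surface_normals` helper using finite differences, not the library's FFT-based gradients):

```python
import numpy as np

def surface_normals(h, dlat, dlon):
    """Unit normals of the graph of a height field h(lat, lon).

    The normal of the surface (lat, lon, h) is (-dh/dlat, -dh/dlon, 1),
    normalized to unit length.
    """
    dh_dlat = np.gradient(h, dlat, axis=0)  # derivative in latitude
    dh_dlon = np.gradient(h, dlon, axis=1)  # derivative in longitude
    n = np.stack([-dh_dlat, -dh_dlon, np.ones_like(h)], axis=-1)
    return n / np.linalg.norm(n, axis=-1, keepdims=True)

normals = surface_normals(np.random.default_rng(0).standard_normal((8, 16)), 0.1, 0.1)
```

A flat field yields normals pointing straight "up" (0, 0, 1) everywhere, which is a quick sanity check on the construction.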
@@ -118,13 +118,21 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
     with values outside :math:`[a, b]` redrawn until they are within
     the bounds. The method used for generating the random values works
     best when :math:`a \leq \text{mean} \leq b`.
 
-    Args:
-        tensor: an n-dimensional `torch.Tensor`
-        mean: the mean of the normal distribution
-        std: the standard deviation of the normal distribution
-        a: the minimum cutoff value
-        b: the maximum cutoff value
-    Examples:
+    Parameters
+    ----------
+    tensor : torch.Tensor
+        an n-dimensional `torch.Tensor`
+    mean : float
+        the mean of the normal distribution
+    std : float
+        the standard deviation of the normal distribution
+    a : float
+        the minimum cutoff value, by default -2.0
+    b : float
+        the maximum cutoff value, by default 2.0
+
+    Examples
+    --------
     >>> w = torch.empty(3, 5)
     >>> nn.init.trunc_normal_(w)
     """
@@ -139,6 +147,20 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
     'survival rate' as the argument.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor
+    drop_prob : float, optional
+        Probability of dropping a path, by default 0.0
+    training : bool, optional
+        Whether the model is in training mode, by default False
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor
     """
     if drop_prob == 0.0 or not training:
         return x
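Per the docstring, `drop_prob` is the probability that a sample's whole residual path is zeroed during training; survivors are rescaled by `1/keep_prob` so the expectation is unchanged. A numpy sketch of that semantics (an illustrative stand-in, not the torch implementation):

```python
import numpy as np

def drop_path(x, drop_prob=0.0, training=False, seed=0):
    """Per-sample stochastic depth: zero whole samples, rescale survivors."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    rng = np.random.default_rng(seed)  # fixed seed only for reproducibility here
    # one Bernoulli draw per sample, broadcast over all remaining dims
    mask = rng.random((x.shape[0],) + (1,) * (x.ndim - 1)) < keep_prob
    return x * mask / keep_prob  # E[output] == x

x = np.ones((4, 2, 3))
out = drop_path(x, drop_prob=0.5, training=True)
```

With an all-ones input and `drop_prob=0.5`, every output entry is either 0 (dropped sample) or 2 (survivor rescaled by 1/0.5).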
@@ -159,7 +181,7 @@ class DropPath(nn.Module):
     training of very deep networks.
 
     Parameters
-    -----------
+    ----------
     drop_prob : float, optional
         Probability of dropping a path, by default None
     """
@@ -173,7 +195,7 @@ class DropPath(nn.Module):
         Forward pass with drop path regularization.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor
@@ -193,7 +215,7 @@ class PatchEmbed(nn.Module):
     higher dimensional embedding space using convolutional layers.
 
     Parameters
-    -----------
+    ----------
     img_size : tuple, optional
         Input image size (height, width), by default (224, 224)
     patch_size : tuple, optional
@@ -220,7 +242,7 @@ class PatchEmbed(nn.Module):
         Forward pass of patch embedding.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor of shape (batch_size, channels, height, width)
@@ -245,7 +267,7 @@ class MLP(nn.Module):
     and an activation function, with optional dropout and gradient checkpointing.
 
     Parameters
-    -----------
+    ----------
     in_features : int
         Number of input features
     hidden_features : int, optional
@@ -301,7 +323,7 @@ class MLP(nn.Module):
         Forward pass with gradient checkpointing.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor
@@ -317,7 +339,7 @@ class MLP(nn.Module):
         Forward pass of the MLP.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor
@@ -364,7 +386,7 @@ class RealFFT2(nn.Module):
         Forward pass: compute real FFT2D.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor
@@ -410,7 +432,7 @@ class InverseRealFFT2(nn.Module):
         Forward pass: compute inverse real FFT2D.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input FFT coefficients
@@ -431,7 +453,7 @@ class LayerNorm(nn.Module):
     applying normalization, and then transposing back.
 
     Parameters
-    -----------
+    ----------
     in_channels : int
         Number of input channels
     eps : float, optional
@@ -458,7 +480,7 @@ class LayerNorm(nn.Module):
         Forward pass with channel dimension handling.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor with channel dimension at -3
@@ -477,7 +499,7 @@ class SpectralConvS2(nn.Module):
     domain via the RealFFT2 and InverseRealFFT2 wrappers.
 
     Parameters
-    -----------
+    ----------
     forward_transform : nn.Module
         Forward transform (SHT or FFT)
     inverse_transform : nn.Module
@@ -538,7 +560,7 @@ class SpectralConvS2(nn.Module):
         Forward pass of spectral convolution.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor
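SpectralConvS2 applies learned weights to spectral coefficients between a forward and an inverse transform. A minimal planar analogue with numpy FFTs (hypothetical weight layout; for simplicity only the lowest non-negative frequencies are kept, whereas the real module operates on spherical harmonic coefficients):

```python
import numpy as np

def spectral_conv(x, weights):
    """Multiply retained Fourier modes by weights, zero all other modes."""
    coeffs = np.fft.rfft2(x)                # forward transform
    lmax, mmax = weights.shape
    out = np.zeros_like(coeffs)
    out[:lmax, :mmax] = coeffs[:lmax, :mmax] * weights  # learned filter on low modes
    return np.fft.irfft2(out, s=x.shape)    # inverse transform

x = np.random.default_rng(0).standard_normal((16, 32))
y = spectral_conv(x, np.ones((4, 4)))  # all-ones weights act as a low-pass filter
```

Keeping every mode with unit weights reproduces the input, which makes the truncate-weight-invert structure easy to verify.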
@@ -576,7 +598,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
     that add positional information to input tensors.
 
     Parameters
-    -----------
+    ----------
     img_shape : tuple, optional
         Image shape (height, width), by default (480, 960)
     grid : str, optional
@@ -616,7 +638,7 @@ class SequencePositionEmbedding(PositionEmbedding):
     used in the original Transformer paper, adapted for 2D spatial data.
 
     Parameters
-    -----------
+    ----------
     img_shape : tuple, optional
         Image shape (height, width), by default (480, 960)
     grid : str, optional
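The sinusoidal scheme from the original Transformer paper that SequencePositionEmbedding adapts can be sketched generically (a hypothetical helper for flattened positions, not this module's exact 2D layout):

```python
import numpy as np

def sinusoidal_embedding(num_pos, dim):
    """pe[p, 2i] = sin(p / 10000^(2i/dim)), pe[p, 2i+1] = cos(same angle)."""
    pos = np.arange(num_pos)[:, None]
    freqs = np.power(10000.0, -np.arange(0, dim, 2) / dim)[None, :]
    angles = pos * freqs          # one frequency per pair of channels
    pe = np.empty((num_pos, dim))
    pe[:, 0::2] = np.sin(angles)  # even channels: sine
    pe[:, 1::2] = np.cos(angles)  # odd channels: cosine
    return pe

pe = sinusoidal_embedding(100, 32)
```

Each channel pair oscillates at a different wavelength, so every position receives a unique, bounded code without any learned parameters.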
@@ -696,7 +718,7 @@ class LearnablePositionEmbedding(PositionEmbedding):
     latitude-only or full latitude-longitude embeddings.
 
     Parameters
-    -----------
+    ----------
     img_shape : tuple, optional
         Image shape (height, width), by default (480, 960)
     grid : str, optional
...
@@ -57,7 +57,7 @@ class DiscreteContinuousEncoder(nn.Module):
     reducing the spatial resolution while maintaining the spectral properties of the data.
 
     Parameters
-    -----------
+    ----------
     in_shape : tuple, optional
         Input shape (nlat, nlon), by default (721, 1440)
     out_shape : tuple, optional
@@ -114,7 +114,7 @@ class DiscreteContinuousEncoder(nn.Module):
         Forward pass of the discrete-continuous encoder.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor with shape (batch, channels, nlat, nlon)
@@ -141,7 +141,7 @@ class DiscreteContinuousDecoder(nn.Module):
     followed by discrete-continuous convolutions to restore spatial resolution.
 
     Parameters
-    -----------
+    ----------
     in_shape : tuple, optional
         Input shape (nlat, nlon), by default (480, 960)
     out_shape : tuple, optional
@@ -209,7 +209,7 @@ class DiscreteContinuousDecoder(nn.Module):
         Forward pass of the discrete-continuous decoder.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor with shape (batch, channels, nlat, nlon)
@@ -232,6 +232,46 @@ class DiscreteContinuousDecoder(nn.Module):
 class SphericalNeuralOperatorBlock(nn.Module):
     """
     Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
+
+    Parameters
+    ----------
+    forward_transform : torch.nn.Module
+        Forward transform to use for the block
+    inverse_transform : torch.nn.Module
+        Inverse transform to use for the block
+    input_dim : int
+        Input dimension
+    output_dim : int
+        Output dimension
+    conv_type : str, optional
+        Type of convolution to use, by default "local"
+    mlp_ratio : float, optional
+        MLP expansion ratio, by default 2.0
+    drop_rate : float, optional
+        Dropout rate, by default 0.0
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    act_layer : torch.nn.Module, optional
+        Activation function to use, by default nn.GELU
+    norm_layer : str, optional
+        Type of normalization to use, by default "none"
+    inner_skip : str, optional
+        Type of inner skip connection to use, by default "none"
+    outer_skip : str, optional
+        Type of outer skip connection to use, by default "identity"
+    use_mlp : bool, optional
+        Whether to use MLP layers, by default True
+    disco_kernel_shape : tuple, optional
+        Kernel shape for the discrete-continuous convolution, by default (3, 3)
+    disco_basis_type : str, optional
+        Filter basis type for the discrete-continuous convolution, by default "morlet"
+    bias : bool, optional
+        Whether to use bias, by default False
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor
     """
 
     def __init__(
@@ -367,7 +407,7 @@ class LocalSphericalNeuralOperator(nn.Module):
     as well as in the encoder and decoders.
 
     Parameters
-    -----------
+    ----------
     img_size : tuple, optional
         Input image size (nlat, nlon), by default (128, 256)
     grid : str, optional
@@ -416,7 +456,7 @@ class LocalSphericalNeuralOperator(nn.Module):
         Whether to use a bias, by default False
 
     Example
-    -----------
+    -------
     >>> model = LocalSphericalNeuralOperator(
     ...     img_shape=(128, 256),
     ...     scale_factor=4,
@@ -429,7 +469,7 @@ class LocalSphericalNeuralOperator(nn.Module):
     torch.Size([1, 2, 128, 256])
 
     References
-    -----------
+    ----------
     .. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
         "Neural Operators with Localized Integral and Differential Kernels" (2024).
         ICML 2024, https://arxiv.org/pdf/2402.16845.
@@ -592,7 +632,7 @@ class LocalSphericalNeuralOperator(nn.Module):
         Forward pass through the complete LSNO model.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor of shape (batch_size, in_chans, height, width)
...
@@ -43,6 +43,40 @@ from functools import partial
 class SphericalFourierNeuralOperatorBlock(nn.Module):
     """
     Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
+
+    Parameters
+    ----------
+    forward_transform : torch.nn.Module
+        Forward transform to use for the block
+    inverse_transform : torch.nn.Module
+        Inverse transform to use for the block
+    input_dim : int
+        Input dimension
+    output_dim : int
+        Output dimension
+    mlp_ratio : float, optional
+        MLP expansion ratio, by default 2.0
+    drop_rate : float, optional
+        Dropout rate, by default 0.0
+    drop_path : float, optional
+        Drop path rate, by default 0.0
+    act_layer : torch.nn.Module, optional
+        Activation function to use, by default nn.GELU
+    norm_layer : str, optional
+        Type of normalization to use, by default "none"
+    inner_skip : str, optional
+        Type of inner skip connection to use, by default "none"
+    outer_skip : str, optional
+        Type of outer skip connection to use, by default "identity"
+    use_mlp : bool, optional
+        Whether to use MLP layers, by default True
+    bias : bool, optional
+        Whether to use bias, by default False
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor
     """
 
     def __init__(
@@ -123,12 +157,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         Forward pass through the SFNO block.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor
 
         Returns
         -------
         torch.Tensor
             Output tensor after processing through the block
         """
@@ -198,7 +232,7 @@ class SphericalFourierNeuralOperator(nn.Module):
         Whether to use a bias, by default False
 
-    Example:
-    --------
+    Example
+    -------
     >>> model = SphericalFourierNeuralOperator(
     ...     img_size=(128, 256),
     ...     scale_factor=4,
@@ -211,7 +245,7 @@ class SphericalFourierNeuralOperator(nn.Module):
     torch.Size([1, 2, 128, 256])
 
     References
-    -----------
+    ----------
     .. [1] Bonev B., Kurth T., Hundt C., Pathak J., Baust M., Kashinath K., Anandkumar A.;
         "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
         ICML 2023, https://arxiv.org/abs/2306.03838.
@@ -385,12 +419,12 @@ class SphericalFourierNeuralOperator(nn.Module):
         Forward pass through the feature extraction layers.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor
 
         Returns
         -------
         torch.Tensor
             Features after processing through the network
         """
@@ -406,12 +440,12 @@ class SphericalFourierNeuralOperator(nn.Module):
         Forward pass through the complete SFNO model.
 
         Parameters
-        -----------
+        ----------
         x : torch.Tensor
             Input tensor of shape (batch_size, in_chans, height, width)
 
         Returns
         -------
         torch.Tensor
             Output tensor of shape (batch_size, out_chans, height, width)
         """
...
@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver
 class PdeDataset(torch.utils.data.Dataset):
-    """Custom Dataset class for PDE training data"""
+    """Custom Dataset class for PDE training data.
+
+    Parameters
+    ----------
+    dt : float
+        Time step
+    nsteps : int
+        Number of solver steps
+    dims : tuple, optional
+        Number of latitude and longitude points, by default (384, 768)
+    grid : str, optional
+        Grid type, by default "equiangular"
+    pde : str, optional
+        PDE type, by default "shallow water equations"
+    initial_condition : str, optional
+        Initial condition type, by default "random"
+    num_examples : int, optional
+        Number of examples, by default 32
+    device : torch.device, optional
+        Device to use, by default torch.device("cpu")
+    normalize : bool, optional
+        Whether to normalize the input and target, by default True
+    stream : torch.cuda.Stream, optional
+        CUDA stream to use, by default None
+
+    Returns
+    -------
+    inp : torch.Tensor
+        Input tensor
+    tar : torch.Tensor
+        Target tensor
+    """
 
     def __init__(
         self,
...
@@ -302,13 +302,19 @@ class ShallowWaterSolver(nn.Module):
         """
         Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
 
-        [1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
-        DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
-
         Returns
         -------
         torch.Tensor
             Initial spectral coefficients for the Galewsky test case
+
+        References
+        ----------
+        .. [1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
+            DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
         """
         device = self.lap.device
@@ -407,6 +413,18 @@ class ShallowWaterSolver(nn.Module):
     def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
         """
         Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
+
+        Parameters
+        ----------
+        uspec : torch.Tensor
+            Spectral coefficients [height, vorticity, divergence]
+        nsteps : int
+            Number of time steps to integrate
+
+        Returns
+        -------
+        torch.Tensor
+            Integrated spectral coefficients
         """
         dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)
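The Adams-Bashforth / forward-Euler combination named in the timestep docstring can be illustrated on a scalar ODE (a generic second-order sketch; the solver's actual right-hand side and state layout differ):

```python
import math

def ab2_integrate(f, y0, dt, nsteps):
    """Second-order Adams-Bashforth, bootstrapped with one forward Euler step."""
    f_prev = f(y0)
    y = y0 + dt * f_prev  # forward Euler for the first step
    for _ in range(nsteps - 1):
        f_curr = f(y)
        y = y + dt * (1.5 * f_curr - 0.5 * f_prev)  # AB2 update
        f_prev = f_curr
    return y

y = ab2_integrate(lambda y: -y, 1.0, 0.01, 100)  # dy/dt = -y, so y(1) ~ e^{-1}
print(abs(y - math.exp(-1)) < 1e-3)  # True
```

The bootstrap step is needed because the multistep update uses the right-hand side at two consecutive times, which is unavailable at the very first step.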
@@ -440,6 +458,23 @@ class ShallowWaterSolver(nn.Module):
         return uspec
 
     def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
+        """
+        Integrate the solution on the grid.
+
+        Parameters
+        ----------
+        ugrid : torch.Tensor
+            Grid data
+        dimensionless : bool, optional
+            Whether to use dimensionless units, by default False
+        polar_opt : int, optional
+            Number of polar points to exclude, by default 0
+
+        Returns
+        -------
+        torch.Tensor
+            Integrated grid data
+        """
         dlon = 2 * torch.pi / self.nlon
         radius = 1 if dimensionless else self.radius
         if polar_opt > 0:
...
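Grid integration of this kind sums samples against quadrature weights; on an equiangular grid the area element is sin(θ) dθ dλ. A plain-Python sketch (hypothetical grid layout, midpoint rule in colatitude) checking that such weights integrate the constant function to the unit sphere's surface area 4π:

```python
import math

def sphere_quadrature(values, nlat, nlon):
    """Integrate a function sampled on an equiangular lat-lon grid."""
    dtheta = math.pi / nlat
    dlon = 2 * math.pi / nlon
    total = 0.0
    for j in range(nlat):
        theta = (j + 0.5) * dtheta           # colatitude at the cell midpoint
        w = math.sin(theta) * dtheta * dlon  # quadrature weight: area element
        for i in range(nlon):
            total += w * values[j][i]
    return total

area = sphere_quadrature([[1.0] * 64 for _ in range(32)], 32, 64)
print(abs(area - 4 * math.pi) < 1e-2)  # True
```

Refining the grid shrinks the quadrature error, consistent with the midpoint rule's second-order convergence.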
@@ -58,8 +58,22 @@ class Stanford2D3DSDownloader:
     """
     Convenience class for downloading the 2d3ds dataset [1].
 
+    Parameters
+    ----------
+    base_url : str, optional
+        Base URL for downloading the dataset, by default DEFAULT_BASE_URL
+    local_dir : str, optional
+        Local directory to store downloaded files, by default "data"
+
+    Returns
+    -------
+    data_folders : list
+        List of extracted directory names
+    class_labels : list
+        List of semantic class labels
+
     References
-    -----------
+    ----------
     .. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
         "Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
         https://arxiv.org/abs/1702.01105.
@@ -483,8 +497,24 @@ class StanfordSegmentationDataset(Dataset):
     """
     Spherical segmentation dataset from [1].
 
+    Parameters
+    ----------
+    dataset_file : str
+        Path to the HDF5 dataset file
+    ignore_alpha_channel : bool, optional
+        Whether to ignore the alpha channel in the RGB images, by default True
+    log_depth : bool, optional
+        Whether to log the depth values, by default False
+    exclude_polar_fraction : float, optional
+        Fraction of polar points to exclude, by default 0.0
+
     References
-    -----------
+    ----------
     .. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
         "Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
         https://arxiv.org/abs/1702.01105.
@@ -620,8 +650,19 @@ class StanfordDepthDataset(Dataset):
     """
     Spherical segmentation dataset from [1].
 
+    Parameters
+    ----------
+    dataset_file : str
+        Path to the HDF5 dataset file
+    ignore_alpha_channel : bool, optional
+        Whether to ignore the alpha channel in the RGB images, by default True
+    log_depth : bool, optional
+        Whether to log the depth values, by default False
+    exclude_polar_fraction : float, optional
+        Fraction of polar points to exclude, by default 0.0
+
     References
-    -----------
+    ----------
     .. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
         "Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
         https://arxiv.org/abs/1702.01105.
...