Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
290da8e0
Commit
290da8e0
authored
Jun 30, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Improved docstrings in examples
parent
6f3250cb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
101 additions
and
8 deletions
+101
-8
examples/baseline_models/segformer.py
examples/baseline_models/segformer.py
+64
-2
examples/baseline_models/unet.py
examples/baseline_models/unet.py
+31
-0
examples/depth/train.py
examples/depth/train.py
+2
-2
examples/model_registry.py
examples/model_registry.py
+4
-4
No files found.
examples/baseline_models/segformer.py
View file @
290da8e0
...
@@ -371,6 +371,41 @@ class AttentionWrapper(nn.Module):
...
@@ -371,6 +371,41 @@ class AttentionWrapper(nn.Module):
class
TransformerBlock
(
nn
.
Module
):
class
TransformerBlock
(
nn
.
Module
):
"""
Transformer block with attention and MLP.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
nrep : int, optional
Number of repetitions of attention and MLP blocks, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (3, 3)
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
att_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path_rates : float or list, optional
Drop path rates for each block, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def
__init__
(
def
__init__
(
self
,
self
,
in_shape
,
in_shape
,
...
@@ -493,6 +528,33 @@ class TransformerBlock(nn.Module):
...
@@ -493,6 +528,33 @@ class TransformerBlock(nn.Module):
class
Upsampling
(
nn
.
Module
):
class
Upsampling
(
nn
.
Module
):
"""
Upsampling block for the Segformer model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP, by default False
"""
def
__init__
(
def
__init__
(
self
,
self
,
in_shape
,
in_shape
,
...
@@ -534,7 +596,7 @@ class Segformer(nn.Module):
...
@@ -534,7 +596,7 @@ class Segformer(nn.Module):
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
Parameters
----------
-
----------
img_shape : tuple, optional
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
Shape of the input channels, by default (128, 256)
kernel_shape: tuple, int
kernel_shape: tuple, int
...
@@ -566,7 +628,7 @@ class Segformer(nn.Module):
...
@@ -566,7 +628,7 @@ class Segformer(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
Example
----------
-
----------
>>> model = Segformer(
>>> model = Segformer(
... img_size=(128, 256),
... img_size=(128, 256),
... in_chans=3,
... in_chans=3,
...
...
examples/baseline_models/unet.py
View file @
290da8e0
...
@@ -43,6 +43,37 @@ from functools import partial
...
@@ -43,6 +43,37 @@ from functools import partial
class
DownsamplingBlock
(
nn
.
Module
):
class
DownsamplingBlock
(
nn
.
Module
):
"""
Downsampling block for the UNet model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def
__init__
(
def
__init__
(
self
,
self
,
in_shape
,
in_shape
,
...
...
examples/depth/train.py
View file @
290da8e0
...
@@ -67,7 +67,7 @@ def count_parameters(model):
...
@@ -67,7 +67,7 @@ def count_parameters(model):
Count the number of trainable parameters in a model.
Count the number of trainable parameters in a model.
Parameters
Parameters
----------
-
----------
model : torch.nn.Module
model : torch.nn.Module
The model to count parameters for
The model to count parameters for
...
@@ -87,7 +87,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
...
@@ -87,7 +87,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
Saves model weights and gradients to a file for analysis.
Saves model weights and gradients to a file for analysis.
Parameters
Parameters
----------
-
----------
exp_dir : str
exp_dir : str
Experiment directory to save logs in
Experiment directory to save logs in
model : torch.nn.Module
model : torch.nn.Module
...
...
examples/model_registry.py
View file @
290da8e0
...
@@ -49,7 +49,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
...
@@ -49,7 +49,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
computational budgets and performance requirements.
computational budgets and performance requirements.
Parameters
Parameters
----------
-
----------
img_size : tuple, optional
img_size : tuple, optional
Input image size as (height, width), by default (128, 256)
Input image size as (height, width), by default (128, 256)
in_chans : int, optional
in_chans : int, optional
...
@@ -64,7 +64,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
...
@@ -64,7 +64,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"
Returns
Returns
-------
-------
---
dict
dict
Dictionary mapping model names to partial functions that can be called
Dictionary mapping model names to partial functions that can be called
to instantiate the corresponding model with the specified parameters.
to instantiate the corresponding model with the specified parameters.
...
@@ -96,13 +96,13 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
...
@@ -96,13 +96,13 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
- vit_sc2_layers4_e128: Vision Transformer variant (medium)
- vit_sc2_layers4_e128: Vision Transformer variant (medium)
Examples
Examples
--------
--------
--
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Notes
Notes
-----
-----
-----
- Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
- Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'e32', 'e128', 'e256' indicate embedding dimensions
- 'e32', 'e128', 'e256' indicate embedding dimensions
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment