Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
290da8e0
Commit
290da8e0
authored
Jun 30, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Improved docstrings in examples
parent
6f3250cb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
101 additions
and
8 deletions
+101
-8
examples/baseline_models/segformer.py
examples/baseline_models/segformer.py
+64
-2
examples/baseline_models/unet.py
examples/baseline_models/unet.py
+31
-0
examples/depth/train.py
examples/depth/train.py
+2
-2
examples/model_registry.py
examples/model_registry.py
+4
-4
No files found.
examples/baseline_models/segformer.py
View file @
290da8e0
...
...
@@ -371,6 +371,41 @@ class AttentionWrapper(nn.Module):
class
TransformerBlock
(
nn
.
Module
):
"""
Transformer block with attention and MLP.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
nrep : int, optional
Number of repetitions of attention and MLP blocks, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (3, 3)
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
att_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path_rates : float or list, optional
Drop path rates for each block, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def
__init__
(
self
,
in_shape
,
...
...
@@ -493,6 +528,33 @@ class TransformerBlock(nn.Module):
class
Upsampling
(
nn
.
Module
):
"""
Upsampling block for the Segformer model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP, by default False
"""
def
__init__
(
self
,
in_shape
,
...
...
@@ -534,7 +596,7 @@ class Segformer(nn.Module):
Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks
Parameters
----------
-
----------
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
kernel_shape: tuple, int
...
...
@@ -566,7 +628,7 @@ class Segformer(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Example
----------
-
----------
>>> model = Segformer(
... img_size=(128, 256),
... in_chans=3,
...
...
examples/baseline_models/unet.py
View file @
290da8e0
...
...
@@ -43,6 +43,37 @@ from functools import partial
class
DownsamplingBlock
(
nn
.
Module
):
"""
Downsampling block for the UNet model.
Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def
__init__
(
self
,
in_shape
,
...
...
examples/depth/train.py
View file @
290da8e0
...
...
@@ -67,7 +67,7 @@ def count_parameters(model):
Count the number of trainable parameters in a model.
Parameters
----------
-
----------
model : torch.nn.Module
The model to count parameters for
...
...
@@ -87,7 +87,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
Saves model weights and gradients to a file for analysis.
Parameters
----------
-
----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
...
...
examples/model_registry.py
View file @
290da8e0
...
...
@@ -49,7 +49,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
computational budgets and performance requirements.
Parameters
----------
-
----------
img_size : tuple, optional
Input image size as (height, width), by default (128, 256)
in_chans : int, optional
...
...
@@ -64,7 +64,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"
Returns
-------
-------
---
dict
Dictionary mapping model names to partial functions that can be called
to instantiate the corresponding model with the specified parameters.
...
...
@@ -96,13 +96,13 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
- vit_sc2_layers4_e128: Vision Transformer variant (medium)
Examples
--------
--------
--
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Notes
-----
-----
-----
- Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'e32', 'e128', 'e256' indicate embedding dimensions
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment