Commit 290da8e0 authored by apaaris, committed by Boris Bonev

Improved docstrings in examples

parent 6f3250cb
@@ -371,6 +371,41 @@ class AttentionWrapper(nn.Module):
class TransformerBlock(nn.Module):
"""
Transformer block with attention and MLP.

Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
mlp_hidden_channels : int
Number of hidden channels in MLP
nrep : int, optional
Number of repetitions of attention and MLP blocks, by default 1
heads : int, optional
Number of attention heads, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
att_drop_rate : float, optional
Attention dropout rate, by default 0.0
drop_path_rates : float or list, optional
Drop path rates for each block, by default 0.0
attention_mode : str, optional
Attention mode ("neighborhood", "global"), by default "neighborhood"
attn_kernel_shape : tuple, optional
Kernel shape for neighborhood attention, by default (7, 7)
bias : bool, optional
Whether to use bias, by default True
"""
def __init__(
self,
in_shape,
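A usage sketch based on the parameters documented above; the channels-first input layout and the concrete values are illustrative assumptions, not taken from this diff:

>>> import torch
>>> block = TransformerBlock(
...     in_shape=(128, 256),
...     out_shape=(128, 256),
...     in_channels=64,
...     out_channels=64,
...     mlp_hidden_channels=128,
...     heads=4,
...     attention_mode="neighborhood",
... )
>>> x = torch.randn(1, 64, 128, 256)  # assumed layout: (batch, channels, height, width)
>>> y = block(x)                      # expected spatial shape: out_shape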
@@ -493,6 +528,33 @@ class TransformerBlock(nn.Module):
class Upsampling(nn.Module):
"""
Upsampling block for the Segformer model.

Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int
Number of hidden channels in MLP
mlp_bias : bool, optional
Whether to use bias in MLP, by default True
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
conv_bias : bool, optional
Whether to use bias in convolution, by default False
activation : torch.nn.Module, optional
Activation function to use, by default nn.GELU
use_mlp : bool, optional
Whether to use MLP, by default False
"""
def __init__(
self,
in_shape,
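Likewise, a hypothetical construction of the upsampling block, assuming it maps a (batch, in_channels, *in_shape) tensor to (batch, out_channels, *out_shape):

>>> import torch
>>> up = Upsampling(
...     in_shape=(64, 128),
...     out_shape=(128, 256),
...     in_channels=128,
...     out_channels=64,
...     hidden_channels=256,
...     use_mlp=True,
... )
>>> x = torch.randn(1, 128, 64, 128)
>>> y = up(x)  # expected spatial shape: (128, 256)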
@@ -534,7 +596,7 @@ class Segformer(nn.Module):
Spherical Segformer model designed to approximate mappings from spherical signals to spherical segmentation masks.

Parameters
----------
img_shape : tuple, optional
Shape of the input as (height, width), by default (128, 256)
kernel_shape : tuple, int
@@ -566,7 +628,7 @@ class Segformer(nn.Module):
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"

Example
----------
>>> model = Segformer(
...     img_size=(128, 256),
...     in_chans=3,
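The constructor call above is cut off by the hunk boundary; assuming it completes with the remaining documented arguments, a forward pass would look like this (shapes follow the img_shape/in_chans parameters and are assumptions, not taken from this diff):

>>> x = torch.randn(1, 3, 128, 256)  # (batch, in_chans, *img_shape)
>>> masks = model(x)                 # per-point segmentation logits on the sphere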
@@ -43,6 +43,37 @@ from functools import partial
class DownsamplingBlock(nn.Module):
"""
Downsampling block for the UNet model.

Parameters
----------
in_shape : tuple
Input shape (height, width)
out_shape : tuple
Output shape (height, width)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
nrep : int, optional
Number of repetitions of conv blocks, by default 1
kernel_shape : tuple, optional
Kernel shape for convolutions, by default (3, 3)
activation : callable, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connections, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.
drop_path_rate : float, optional
Drop path rate, by default 0.
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def __init__(
self,
in_shape,
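A hypothetical construction mirroring the docstring; whether the block also returns a skip tensor depends on the implementation, which this hunk does not show:

>>> import torch
>>> down = DownsamplingBlock(
...     in_shape=(128, 256),
...     out_shape=(64, 128),
...     in_channels=32,
...     out_channels=64,
...     downsampling_mode="bilinear",
... )
>>> x = torch.randn(1, 32, 128, 256)
>>> y = down(x)  # expected spatial shape: (64, 128)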
@@ -67,7 +67,7 @@ def count_parameters(model):
Count the number of trainable parameters in a model.

Parameters
----------
model : torch.nn.Module
The model to count parameters for
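The body is outside this hunk; the conventional implementation of such a helper is the following one-liner (a sketch, not necessarily the code in this repository):

>>> def count_parameters(model):
...     return sum(p.numel() for p in model.parameters() if p.requires_grad)
>>> import torch.nn as nn
>>> count_parameters(nn.Linear(4, 2))  # 4*2 weights + 2 biases
10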
@@ -87,7 +87,7 @@ def log_weights_and_grads(exp_dir, model, iters=1):
Saves model weights and gradients to a file for analysis.

Parameters
----------
exp_dir : str
Experiment directory to save logs in
model : torch.nn.Module
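A usage sketch; the stand-in model, the backward call that populates gradients, and the directory name are illustrative assumptions:

>>> import torch
>>> import torch.nn as nn
>>> model = nn.Linear(4, 2)                    # stand-in model
>>> model(torch.randn(8, 4)).sum().backward()  # populate .grad before logging
>>> log_weights_and_grads("runs/exp1", model, iters=1)  # "runs/exp1" is hypothetical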
@@ -49,7 +49,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
computational budgets and performance requirements.

Parameters
----------
img_size : tuple, optional
Input image size as (height, width), by default (128, 256)
in_chans : int, optional
@@ -64,7 +64,7 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
Grid type for spherical models ("equiangular", "legendre-gauss", etc.), by default "equiangular"

Returns
----------
dict
Dictionary mapping model names to partial functions that can be called
to instantiate the corresponding model with the specified parameters.
@@ -96,13 +96,13 @@ def get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3, residual_p
- vit_sc2_layers4_e128: Vision Transformer variant (medium)

Examples
----------
>>> model_registry = get_baseline_models(img_size=(64, 128), in_chans=2, out_chans=1)
>>> model = model_registry['sfno_sc2_layers4_e32']()
>>> print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Notes
----------
- Model names follow the pattern: {model_type}_{scale_factor}_{layers}_{embed_dim}
- 'sc2' indicates scale factor of 2 (downsampling by 2)
- 'e32', 'e128', 'e256' indicate embedding dimensions
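Since each registry value is a partial constructor, the whole registry can be instantiated and sized in a loop; a sketch using only the documented dict-of-partials contract:

>>> registry = get_baseline_models(img_size=(128, 256), in_chans=3, out_chans=3)
>>> for name, build in registry.items():
...     model = build()  # each value instantiates one baseline
...     print(f"{name}: {sum(p.numel() for p in model.parameters()):,} parameters")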