Commit 313b1b73 authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Corrected docstrings in _layers.py

parent e4879676
......@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
"""
Internal function to fill tensor with truncated normal distribution values.
Initialize tensor with truncated normal distribution without gradients.
This is a helper function for trunc_normal_ that performs the actual initialization
without requiring gradients to be tracked.
Parameters
-----------
tensor : torch.Tensor
Tensor to fill with values
Tensor to initialize
mean : float
Mean of the normal distribution
std : float
......@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
Returns
-------
torch.Tensor
The filled tensor
Initialized tensor
"""
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
"""
Compute standard normal cumulative distribution function.
Parameters
-----------
x : float
Input value
Returns
-------
float
CDF value
"""
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
......@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Parameters
-----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Dropout probability, by default 0.0
training : bool, optional
Whether in training mode, by default False
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
"""
if drop_prob == 0.0 or not training:
return x
......@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This module implements stochastic depth regularization by randomly dropping
entire residual paths during training, which helps with regularization and
training of very deep networks.
Parameters
-----------
drop_prob : float, optional
Probability of dropping a path, by default None
"""
def __init__(self, drop_prob=None):
"""
Initialize DropPath module.
Parameters
-----------
drop_prob : float, optional
Dropout probability, by default None
"""
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
"""
Forward pass with drop path.
Forward pass with drop path regularization.
Parameters
-----------
......@@ -177,7 +180,7 @@ class DropPath(nn.Module):
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
Output tensor with potential path dropping
"""
return drop_path(x, self.drop_prob, self.training)
......@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module):
"""
Patch embedding layer for vision transformers.
This module splits input images into patches and projects them to a
higher dimensional embedding space using convolutional layers.
Parameters
-----------
img_size : tuple, optional
......@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module):
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, height, width)
Input tensor of shape (batch_size, channels, height, width)
Returns
-------
torch.Tensor
Embedded patches with shape (batch, embed_dim, num_patches)
Patch embeddings of shape (batch_size, embed_dim, num_patches)
"""
# gather input
B, C, H, W = x.shape
......@@ -235,6 +241,9 @@ class MLP(nn.Module):
"""
Multi-layer perceptron with optional checkpointing.
This module implements a feed-forward network with two linear layers
and an activation function, with optional dropout and gradient checkpointing.
Parameters
-----------
in_features : int
......@@ -252,7 +261,7 @@ class MLP(nn.Module):
checkpointing : bool, optional
Whether to use gradient checkpointing, by default False
gain : float, optional
Gain factor for output initialization, by default 1.0
Gain factor for weight initialization, by default 1.0
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
......@@ -325,24 +334,24 @@ class MLP(nn.Module):
class RealFFT2(nn.Module):
"""
Helper routine to wrap FFT similarly to the SHT
Helper routine to wrap FFT similarly to the SHT.
This module provides a wrapper around PyTorch's real FFT2D that mimics
the interface of spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
"""
Initialize RealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super(RealFFT2, self).__init__()
self.nlat = nlat
......@@ -352,17 +361,17 @@ class RealFFT2(nn.Module):
def forward(self, x):
"""
Forward pass of RealFFT2.
Forward pass: compute real FFT2D.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
Input tensor
Returns
-------
torch.Tensor
Output tensor with shape (batch, channels, nlat, mmax)
FFT coefficients
"""
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
......@@ -371,24 +380,24 @@ class RealFFT2(nn.Module):
class InverseRealFFT2(nn.Module):
"""
Helper routine to wrap inverse FFT similarly to the SHT
Helper routine to wrap inverse FFT similarly to the SHT.
This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
the interface of inverse spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None):
"""
Initialize InverseRealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super(InverseRealFFT2, self).__init__()
self.nlat = nlat
......@@ -398,45 +407,46 @@ class InverseRealFFT2(nn.Module):
def forward(self, x):
"""
Forward pass of InverseRealFFT2.
Forward pass: compute inverse real FFT2D.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, mmax)
Input FFT coefficients
Returns
-------
torch.Tensor
Output tensor with shape (batch, channels, nlat, nlon)
Reconstructed spatial signal
"""
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class LayerNorm(nn.Module):
"""
Wrapper class that moves the channel dimension to the end
Wrapper class that moves the channel dimension to the end.
This module provides a layer normalization that works with channel-first
tensors by temporarily transposing the channel dimension to the end,
applying normalization, and then transposing back.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type for the module, by default None
"""
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
"""
Initialize LayerNorm module.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type, by default None
"""
super().__init__()
self.channel_dim = -3
......@@ -445,31 +455,33 @@ class LayerNorm(nn.Module):
def forward(self, x):
"""
Forward pass of LayerNorm.
Forward pass with channel dimension handling.
Parameters
-----------
x : torch.Tensor
Input tensor
Input tensor with channel dimension at -3
Returns
-------
torch.Tensor
Normalized tensor
Normalized tensor with same shape as input
"""
return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
class SpectralConvS2(nn.Module):
"""
Spectral convolution layer for spherical data.
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
-----------
forward_transform : nn.Module
Forward transform (e.g., RealSHT)
Forward transform (SHT or FFT)
inverse_transform : nn.Module
Inverse transform (e.g., InverseRealSHT)
Inverse transform (ISHT or IFFT)
in_channels : int
Number of input channels
out_channels : int
......@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module):
gain : float, optional
Gain factor for weight initialization, by default 2.0
operator_type : str, optional
Type of spectral operator, by default "driscoll-healy"
Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
lr_scale_exponent : int, optional
Learning rate scale exponent, by default 0
Learning rate scaling exponent, by default 0
bias : bool, optional
Whether to use bias, by default False
"""
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
super(SpectralConvS2, self).__init__()
super().__init__()
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.in_channels = in_channels
self.out_channels = out_channels
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.operator_type = operator_type
self.lr_scale_exponent = lr_scale_exponent
# initialize the weights
scale = math.sqrt(gain / in_channels)
self.weight = nn.Parameter(scale * torch.randn(out_channels, in_channels, dtype=torch.cfloat))
assert self.inverse_transform.lmax == self.modes_lat
assert self.inverse_transform.mmax == self.modes_lon
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels, dtype=torch.cfloat))
weight_shape = [out_channels, in_channels]
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
self.contract_func = "...ilm,oilm->...olm"
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
self.contract_func = "...ilm,oilnm->...oln"
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
self.contract_func = "...ilm,oil->...olm"
else:
self.bias = None
raise NotImplementedError(f"Unknown operator type {self.operator_type}")
# form weight tensors
scale = math.sqrt(gain / in_channels)
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
"""
......@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module):
Returns
-------
torch.Tensor
Output tensor after spectral convolution
tuple
Tuple containing (output, residual) tensors
"""
# apply forward transform
x = self.forward_transform(x)
dtype = x.dtype
x = x.float()
residual = x
# apply spectral convolution
x = torch.einsum("bilm,oim->bolm", x, self.weight)
with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
# apply inverse transform
x = self.inverse_transform(x)
x = torch.einsum(self.contract_func, x, self.weight)
# add bias if present
if self.bias is not None:
x = x + self.bias.view(1, -1, 1, 1)
with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)
return x
if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
return x, residual
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for position embeddings on spherical data.
Abstract base class for position embeddings.
This class defines the interface for position embedding modules
that add positional information to input tensors.
Parameters
-----------
......@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(PositionEmbedding, self).__init__()
super().__init__()
self.img_shape = img_shape
self.grid = grid
self.num_chans = num_chans
@abc.abstractmethod
def forward(self, x: torch.Tensor):
"""
Abstract forward method for position embedding.
Forward pass: add position embeddings to input.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Input tensor with position embeddings added
"""
pass
return x + self.position_embeddings
class SequencePositionEmbedding(PositionEmbedding):
"""
Sequence-based position embedding for spherical data.
Standard sequence-based position embedding.
This module adds position embeddings based on the sequence of latitude and longitude
coordinates, providing spatial context to the model.
This module implements sinusoidal position embeddings similar to those
used in the original Transformer paper, adapted for 2D spatial data.
Parameters
-----------
......@@ -582,38 +624,29 @@ class SequencePositionEmbedding(PositionEmbedding):
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(SequencePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# create position embeddings
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1])
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_buffer("pos_embed", pos_embed)
with torch.no_grad():
# alternating custom position embeddings
pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
denom = torch.pow(10000, 2 * k / self.num_chans)
def forward(self, x: torch.Tensor):
"""
Forward pass of sequence position embedding.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with position embeddings added
"""
return x + self.pos_embed
pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))
# register tensor
self.register_buffer("position_embeddings", pos_embed.float())
class SpectralPositionEmbedding(PositionEmbedding):
r"""
Spectral position embedding for spherical data.
"""
Spectral position embeddings for spherical transformers.
This module adds position embeddings in the spectral domain using spherical harmonics,
providing spectral context to the model.
This module creates position embeddings in the spectral domain using
spherical harmonic functions, which are particularly suitable for
spherical data processing.
Parameters
-----------
......@@ -624,39 +657,43 @@ class SpectralPositionEmbedding(PositionEmbedding):
num_chans : int, optional
Number of channels, by default 1
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(SpectralPositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# create spectral position embeddings
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1] // 2 + 1, dtype=torch.cfloat)
nn.init.trunc_normal_(pos_embed.real, std=0.02)
nn.init.trunc_normal_(pos_embed.imag, std=0.02)
self.register_buffer("pos_embed", pos_embed)
# compute maximum required frequency and prepare isht
lmax = math.floor(math.sqrt(self.num_chans)) + 1
isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)
def forward(self, x: torch.Tensor):
"""
Forward pass of spectral position embedding.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with spectral position embeddings added
"""
return x + self.pos_embed
# fill position embedding
with torch.no_grad():
pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)
for i in range(self.num_chans):
l = math.floor(math.sqrt(i))
m = i - l**2 - l
if m < 0:
pos_embed_freq[0, i, l, -m] = 1.0j
else:
pos_embed_freq[0, i, l, m] = 1.0
# compute spatial position embeddings
pos_embed = isht(pos_embed_freq)
# normalization
pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)
# register tensor
self.register_buffer("position_embeddings", pos_embed)
class LearnablePositionEmbedding(PositionEmbedding):
r"""
Learnable position embedding for spherical data.
"""
Learnable position embeddings for spherical transformers.
This module adds learnable position embeddings that are optimized during training,
allowing the model to learn optimal spatial representations.
This module provides learnable position embeddings that can be either
latitude-only or full latitude-longitude embeddings.
Parameters
-----------
......@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding):
num_chans : int, optional
Number of channels, by default 1
embed_type : str, optional
Embedding type ("lat", "lon", or "both"), by default "lat"
Embedding type ("lat" or "latlon"), by default "lat"
"""
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super(LearnablePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
self.embed_type = embed_type
if embed_type == "lat":
# latitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], 1))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "lon":
# longitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, 1, img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "latlon":
# full lat-lon embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
else:
raise ValueError(f"Unknown embedding type {embed_type}")
super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
def forward(self, x: torch.Tensor):
"""
Forward pass of learnable position embedding.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with learnable position embeddings added
"""
return x + self.pos_embed
if embed_type == "latlon":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
elif embed_type == "lat":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
else:
raise ValueError(f"Unknown learnable position embedding type {embed_type}")
# class SpiralPositionEmbedding(PositionEmbedding):
# """
......@@ -731,4 +739,4 @@ class LearnablePositionEmbedding(PositionEmbedding):
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# # register tensor
# self.register_buffer("position_embeddings", pos_embed.float())
\ No newline at end of file
# self.register_buffer("position_embeddings", pos_embed.float())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment