Commit 313b1b73 authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Corrected docstrings in _layers.py

parent e4879676
...@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT ...@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT
def _no_grad_trunc_normal_(tensor, mean, std, a, b): def _no_grad_trunc_normal_(tensor, mean, std, a, b):
""" """
Internal function to fill tensor with truncated normal distribution values. Initialize tensor with truncated normal distribution without gradients.
This is a helper function for trunc_normal_ that performs the actual initialization
without requiring gradients to be tracked.
Parameters Parameters
----------- -----------
tensor : torch.Tensor tensor : torch.Tensor
Tensor to fill with values Tensor to initialize
mean : float mean : float
Mean of the normal distribution Mean of the normal distribution
std : float std : float
...@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): ...@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
Returns Returns
------- -------
torch.Tensor torch.Tensor
The filled tensor Initialized tensor
""" """
# Cut & paste from PyTorch official master until it's in a few official releases - RW # Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x): def norm_cdf(x):
"""
Compute standard normal cumulative distribution function.
Parameters
-----------
x : float
Input value
Returns
-------
float
CDF value
"""
# Computes standard normal cumulative distribution function # Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
...@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): ...@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
@torch.jit.script @torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
""" """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument. 'survival rate' as the argument.
Parameters
-----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Dropout probability, by default 0.0
training : bool, optional
Whether in training mode, by default False
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
""" """
if drop_prob == 0.0 or not training: if drop_prob == 0.0 or not training:
return x return x
...@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) - ...@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
class DropPath(nn.Module): class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
""" """
Initialize DropPath module. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This module implements stochastic depth regularization by randomly dropping
entire residual paths during training, which helps with regularization and
training of very deep networks.
Parameters Parameters
----------- -----------
drop_prob : float, optional drop_prob : float, optional
Dropout probability, by default None Probability of dropping a path, by default None
""" """
def __init__(self, drop_prob=None):
super(DropPath, self).__init__() super(DropPath, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x): def forward(self, x):
""" """
Forward pass with drop path. Forward pass with drop path regularization.
Parameters Parameters
----------- -----------
...@@ -177,7 +180,7 @@ class DropPath(nn.Module): ...@@ -177,7 +180,7 @@ class DropPath(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor with potential drop path applied Output tensor with potential path dropping
""" """
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
...@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module): ...@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module):
""" """
Patch embedding layer for vision transformers. Patch embedding layer for vision transformers.
This module splits input images into patches and projects them to a
higher dimensional embedding space using convolutional layers.
Parameters Parameters
----------- -----------
img_size : tuple, optional img_size : tuple, optional
...@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module): ...@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module):
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor with shape (batch, channels, height, width) Input tensor of shape (batch_size, channels, height, width)
Returns Returns
------- -------
torch.Tensor torch.Tensor
Embedded patches with shape (batch, embed_dim, num_patches) Patch embeddings of shape (batch_size, embed_dim, num_patches)
""" """
# gather input # gather input
B, C, H, W = x.shape B, C, H, W = x.shape
...@@ -235,6 +241,9 @@ class MLP(nn.Module): ...@@ -235,6 +241,9 @@ class MLP(nn.Module):
""" """
Multi-layer perceptron with optional checkpointing. Multi-layer perceptron with optional checkpointing.
This module implements a feed-forward network with two linear layers
and an activation function, with optional dropout and gradient checkpointing.
Parameters Parameters
----------- -----------
in_features : int in_features : int
...@@ -252,7 +261,7 @@ class MLP(nn.Module): ...@@ -252,7 +261,7 @@ class MLP(nn.Module):
checkpointing : bool, optional checkpointing : bool, optional
Whether to use gradient checkpointing, by default False Whether to use gradient checkpointing, by default False
gain : float, optional gain : float, optional
Gain factor for output initialization, by default 1.0 Gain factor for weight initialization, by default 1.0
""" """
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
...@@ -325,12 +334,10 @@ class MLP(nn.Module): ...@@ -325,12 +334,10 @@ class MLP(nn.Module):
class RealFFT2(nn.Module): class RealFFT2(nn.Module):
""" """
Helper routine to wrap FFT similarly to the SHT Helper routine to wrap FFT similarly to the SHT.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None): This module provides a wrapper around PyTorch's real FFT2D that mimics
""" the interface of spherical harmonic transforms for consistency.
Initialize RealFFT2 module.
Parameters Parameters
----------- -----------
...@@ -339,10 +346,12 @@ class RealFFT2(nn.Module): ...@@ -339,10 +346,12 @@ class RealFFT2(nn.Module):
nlon : int nlon : int
Number of longitude points Number of longitude points
lmax : int, optional lmax : int, optional
Maximum l mode, by default None (same as nlat) Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1) Maximum spherical harmonic order, by default None (nlon//2 + 1)
""" """
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(RealFFT2, self).__init__() super(RealFFT2, self).__init__()
self.nlat = nlat self.nlat = nlat
...@@ -352,17 +361,17 @@ class RealFFT2(nn.Module): ...@@ -352,17 +361,17 @@ class RealFFT2(nn.Module):
def forward(self, x): def forward(self, x):
""" """
Forward pass of RealFFT2. Forward pass: compute real FFT2D.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon) Input tensor
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor with shape (batch, channels, nlat, mmax) FFT coefficients
""" """
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho") y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2) y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
...@@ -371,12 +380,10 @@ class RealFFT2(nn.Module): ...@@ -371,12 +380,10 @@ class RealFFT2(nn.Module):
class InverseRealFFT2(nn.Module): class InverseRealFFT2(nn.Module):
""" """
Helper routine to wrap inverse FFT similarly to the SHT Helper routine to wrap inverse FFT similarly to the SHT.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None): This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
""" the interface of inverse spherical harmonic transforms for consistency.
Initialize InverseRealFFT2 module.
Parameters Parameters
----------- -----------
...@@ -385,10 +392,12 @@ class InverseRealFFT2(nn.Module): ...@@ -385,10 +392,12 @@ class InverseRealFFT2(nn.Module):
nlon : int nlon : int
Number of longitude points Number of longitude points
lmax : int, optional lmax : int, optional
Maximum l mode, by default None (same as nlat) Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1) Maximum spherical harmonic order, by default None (nlon//2 + 1)
""" """
def __init__(self, nlat, nlon, lmax=None, mmax=None):
super(InverseRealFFT2, self).__init__() super(InverseRealFFT2, self).__init__()
self.nlat = nlat self.nlat = nlat
...@@ -398,29 +407,28 @@ class InverseRealFFT2(nn.Module): ...@@ -398,29 +407,28 @@ class InverseRealFFT2(nn.Module):
def forward(self, x): def forward(self, x):
""" """
Forward pass of InverseRealFFT2. Forward pass: compute inverse real FFT2D.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor with shape (batch, channels, nlat, mmax) Input FFT coefficients
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor with shape (batch, channels, nlat, nlon) Reconstructed spatial signal
""" """
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho") return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
""" """
Wrapper class that moves the channel dimension to the end Wrapper class that moves the channel dimension to the end.
"""
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None): This module provides a layer normalization that works with channel-first
""" tensors by temporarily transposing the channel dimension to the end,
Initialize LayerNorm module. applying normalization, and then transposing back.
Parameters Parameters
----------- -----------
...@@ -435,8 +443,10 @@ class LayerNorm(nn.Module): ...@@ -435,8 +443,10 @@ class LayerNorm(nn.Module):
device : torch.device, optional device : torch.device, optional
Device to place the module on, by default None Device to place the module on, by default None
dtype : torch.dtype, optional dtype : torch.dtype, optional
Data type, by default None Data type for the module, by default None
""" """
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
super().__init__() super().__init__()
self.channel_dim = -3 self.channel_dim = -3
...@@ -445,31 +455,33 @@ class LayerNorm(nn.Module): ...@@ -445,31 +455,33 @@ class LayerNorm(nn.Module):
def forward(self, x): def forward(self, x):
""" """
Forward pass of LayerNorm. Forward pass with channel dimension handling.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor Input tensor with channel dimension at -3
Returns Returns
------- -------
torch.Tensor torch.Tensor
Normalized tensor Normalized tensor with same shape as input
""" """
return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim) return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
class SpectralConvS2(nn.Module): class SpectralConvS2(nn.Module):
""" """
Spectral convolution layer for spherical data. Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters Parameters
----------- -----------
forward_transform : nn.Module forward_transform : nn.Module
Forward transform (e.g., RealSHT) Forward transform (SHT or FFT)
inverse_transform : nn.Module inverse_transform : nn.Module
Inverse transform (e.g., InverseRealSHT) Inverse transform (ISHT or IFFT)
in_channels : int in_channels : int
Number of input channels Number of input channels
out_channels : int out_channels : int
...@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module): ...@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module):
gain : float, optional gain : float, optional
Gain factor for weight initialization, by default 2.0 Gain factor for weight initialization, by default 2.0
operator_type : str, optional operator_type : str, optional
Type of spectral operator, by default "driscoll-healy" Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
lr_scale_exponent : int, optional lr_scale_exponent : int, optional
Learning rate scale exponent, by default 0 Learning rate scaling exponent, by default 0
bias : bool, optional bias : bool, optional
Whether to use bias, by default False Whether to use bias, by default False
""" """
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False): def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
super(SpectralConvS2, self).__init__() super().__init__()
self.forward_transform = forward_transform self.forward_transform = forward_transform
self.inverse_transform = inverse_transform self.inverse_transform = inverse_transform
self.in_channels = in_channels
self.out_channels = out_channels self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.operator_type = operator_type self.operator_type = operator_type
self.lr_scale_exponent = lr_scale_exponent
# initialize the weights assert self.inverse_transform.lmax == self.modes_lat
scale = math.sqrt(gain / in_channels) assert self.inverse_transform.mmax == self.modes_lon
self.weight = nn.Parameter(scale * torch.randn(out_channels, in_channels, dtype=torch.cfloat))
if bias: weight_shape = [out_channels, in_channels]
self.bias = nn.Parameter(torch.zeros(out_channels, dtype=torch.cfloat))
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
self.contract_func = "...ilm,oilm->...olm"
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
self.contract_func = "...ilm,oilnm->...oln"
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
self.contract_func = "...ilm,oil->...olm"
else: else:
self.bias = None raise NotImplementedError(f"Unknown operator type {self.operator_type}")
# form weight tensors
scale = math.sqrt(gain / in_channels)
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x): def forward(self, x):
""" """
...@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module): ...@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module):
Returns Returns
------- -------
torch.Tensor tuple
Output tensor after spectral convolution Tuple containing (output, residual) tensors
""" """
# apply forward transform dtype = x.dtype
x = x.float()
residual = x
with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x) x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
# apply spectral convolution x = torch.einsum(self.contract_func, x, self.weight)
x = torch.einsum("bilm,oim->bolm", x, self.weight)
# apply inverse transform with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x) x = self.inverse_transform(x)
# add bias if present if hasattr(self, "bias"):
if self.bias is not None: x = x + self.bias
x = x + self.bias.view(1, -1, 1, 1) x = x.type(dtype)
return x return x, residual
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
""" """
Abstract base class for position embeddings on spherical data. Abstract base class for position embeddings.
This class defines the interface for position embedding modules
that add positional information to input tensors.
Parameters Parameters
----------- -----------
...@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): ...@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(PositionEmbedding, self).__init__() super().__init__()
self.img_shape = img_shape self.img_shape = img_shape
self.grid = grid
self.num_chans = num_chans self.num_chans = num_chans
@abc.abstractmethod
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
""" """
Abstract forward method for position embedding. Forward pass: add position embeddings to input.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor Input tensor
Returns
-------
torch.Tensor
Input tensor with position embeddings added
""" """
pass return x + self.position_embeddings
class SequencePositionEmbedding(PositionEmbedding): class SequencePositionEmbedding(PositionEmbedding):
""" """
Sequence-based position embedding for spherical data. Standard sequence-based position embedding.
This module adds position embeddings based on the sequence of latitude and longitude This module implements sinusoidal position embeddings similar to those
coordinates, providing spatial context to the model. used in the original Transformer paper, adapted for 2D spatial data.
Parameters Parameters
----------- -----------
...@@ -584,36 +626,27 @@ class SequencePositionEmbedding(PositionEmbedding): ...@@ -584,36 +626,27 @@ class SequencePositionEmbedding(PositionEmbedding):
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(SequencePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# create position embeddings
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1])
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_buffer("pos_embed", pos_embed)
def forward(self, x: torch.Tensor): with torch.no_grad():
""" # alternating custom position embeddings
Forward pass of sequence position embedding. pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
denom = torch.pow(10000, 2 * k / self.num_chans)
Parameters pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))
-----------
x : torch.Tensor
Input tensor
Returns # register tensor
------- self.register_buffer("position_embeddings", pos_embed.float())
torch.Tensor
Tensor with position embeddings added
"""
return x + self.pos_embed
class SpectralPositionEmbedding(PositionEmbedding): class SpectralPositionEmbedding(PositionEmbedding):
r""" """
Spectral position embedding for spherical data. Spectral position embeddings for spherical transformers.
This module adds position embeddings in the spectral domain using spherical harmonics, This module creates position embeddings in the spectral domain using
providing spectral context to the model. spherical harmonic functions, which are particularly suitable for
spherical data processing.
Parameters Parameters
----------- -----------
...@@ -626,37 +659,41 @@ class SpectralPositionEmbedding(PositionEmbedding): ...@@ -626,37 +659,41 @@ class SpectralPositionEmbedding(PositionEmbedding):
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(SpectralPositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# create spectral position embeddings # compute maximum required frequency and prepare isht
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1] // 2 + 1, dtype=torch.cfloat) lmax = math.floor(math.sqrt(self.num_chans)) + 1
nn.init.trunc_normal_(pos_embed.real, std=0.02) isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)
nn.init.trunc_normal_(pos_embed.imag, std=0.02)
self.register_buffer("pos_embed", pos_embed)
def forward(self, x: torch.Tensor): # fill position embedding
""" with torch.no_grad():
Forward pass of spectral position embedding. pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)
Parameters for i in range(self.num_chans):
----------- l = math.floor(math.sqrt(i))
x : torch.Tensor m = i - l**2 - l
Input tensor
Returns if m < 0:
------- pos_embed_freq[0, i, l, -m] = 1.0j
torch.Tensor else:
Tensor with spectral position embeddings added pos_embed_freq[0, i, l, m] = 1.0
"""
return x + self.pos_embed # compute spatial position embeddings
pos_embed = isht(pos_embed_freq)
# normalization
pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)
# register tensor
self.register_buffer("position_embeddings", pos_embed)
class LearnablePositionEmbedding(PositionEmbedding): class LearnablePositionEmbedding(PositionEmbedding):
r""" """
Learnable position embedding for spherical data. Learnable position embeddings for spherical transformers.
This module adds learnable position embeddings that are optimized during training, This module provides learnable position embeddings that can be either
allowing the model to learn optimal spatial representations. latitude-only or full latitude-longitude embeddings.
Parameters Parameters
----------- -----------
...@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding): ...@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding):
num_chans : int, optional num_chans : int, optional
Number of channels, by default 1 Number of channels, by default 1
embed_type : str, optional embed_type : str, optional
Embedding type ("lat", "lon", or "both"), by default "lat" Embedding type ("lat" or "latlon"), by default "lat"
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super(LearnablePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
self.embed_type = embed_type
if embed_type == "lat":
# latitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], 1))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "lon":
# longitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, 1, img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "latlon":
# full lat-lon embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
else:
raise ValueError(f"Unknown embedding type {embed_type}")
def forward(self, x: torch.Tensor): if embed_type == "latlon":
""" self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
Forward pass of learnable position embedding. elif embed_type == "lat":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
Parameters else:
----------- raise ValueError(f"Unknown learnable position embedding type {embed_type}")
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with learnable position embeddings added
"""
return x + self.pos_embed
# class SpiralPositionEmbedding(PositionEmbedding): # class SpiralPositionEmbedding(PositionEmbedding):
# """ # """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment