Commit 313b1b73 authored by apaaris, committed by Boris Bonev
Browse files

Corrected docstrings in _layers.py

parent e4879676
...@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT ...@@ -42,12 +42,15 @@ from torch_harmonics import InverseRealSHT
def _no_grad_trunc_normal_(tensor, mean, std, a, b): def _no_grad_trunc_normal_(tensor, mean, std, a, b):
""" """
Internal function to fill tensor with truncated normal distribution values. Initialize tensor with truncated normal distribution without gradients.
This is a helper function for trunc_normal_ that performs the actual initialization
without requiring gradients to be tracked.
Parameters Parameters
----------- -----------
tensor : torch.Tensor tensor : torch.Tensor
Tensor to fill with values Tensor to initialize
mean : float mean : float
Mean of the normal distribution Mean of the normal distribution
std : float std : float
...@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): ...@@ -60,11 +63,24 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
Returns Returns
------- -------
torch.Tensor torch.Tensor
The filled tensor Initialized tensor
""" """
# Cut & paste from PyTorch official master until it's in a few official releases - RW # Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x): def norm_cdf(x):
"""
Compute standard normal cumulative distribution function.
Parameters
-----------
x : float
Input value
Returns
-------
float
CDF value
"""
# Computes standard normal cumulative distribution function # Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
...@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): ...@@ -117,28 +133,12 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
@torch.jit.script @torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
""" """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument. 'survival rate' as the argument.
Parameters
-----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Dropout probability, by default 0.0
training : bool, optional
Whether in training mode, by default False
Returns
-------
torch.Tensor
Output tensor with potential drop path applied
""" """
if drop_prob == 0.0 or not training: if drop_prob == 0.0 or not training:
return x return x
...@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) - ...@@ -151,23 +151,26 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
class DropPath(nn.Module): class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" """
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This module implements stochastic depth regularization by randomly dropping
entire residual paths during training, which helps with regularization and
training of very deep networks.
Parameters
-----------
drop_prob : float, optional
Probability of dropping a path, by default None
"""
def __init__(self, drop_prob=None): def __init__(self, drop_prob=None):
"""
Initialize DropPath module.
Parameters
-----------
drop_prob : float, optional
Dropout probability, by default None
"""
super(DropPath, self).__init__() super(DropPath, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
def forward(self, x): def forward(self, x):
""" """
Forward pass with drop path. Forward pass with drop path regularization.
Parameters Parameters
----------- -----------
...@@ -177,7 +180,7 @@ class DropPath(nn.Module): ...@@ -177,7 +180,7 @@ class DropPath(nn.Module):
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor with potential drop path applied Output tensor with potential path dropping
""" """
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
...@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module): ...@@ -186,6 +189,9 @@ class PatchEmbed(nn.Module):
""" """
Patch embedding layer for vision transformers. Patch embedding layer for vision transformers.
This module splits input images into patches and projects them to a
higher dimensional embedding space using convolutional layers.
Parameters Parameters
----------- -----------
img_size : tuple, optional img_size : tuple, optional
...@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module): ...@@ -216,12 +222,12 @@ class PatchEmbed(nn.Module):
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor with shape (batch, channels, height, width) Input tensor of shape (batch_size, channels, height, width)
Returns Returns
------- -------
torch.Tensor torch.Tensor
Embedded patches with shape (batch, embed_dim, num_patches) Patch embeddings of shape (batch_size, embed_dim, num_patches)
""" """
# gather input # gather input
B, C, H, W = x.shape B, C, H, W = x.shape
...@@ -235,6 +241,9 @@ class MLP(nn.Module): ...@@ -235,6 +241,9 @@ class MLP(nn.Module):
""" """
Multi-layer perceptron with optional checkpointing. Multi-layer perceptron with optional checkpointing.
This module implements a feed-forward network with two linear layers
and an activation function, with optional dropout and gradient checkpointing.
Parameters Parameters
----------- -----------
in_features : int in_features : int
...@@ -252,7 +261,7 @@ class MLP(nn.Module): ...@@ -252,7 +261,7 @@ class MLP(nn.Module):
checkpointing : bool, optional checkpointing : bool, optional
Whether to use gradient checkpointing, by default False Whether to use gradient checkpointing, by default False
gain : float, optional gain : float, optional
Gain factor for output initialization, by default 1.0 Gain factor for weight initialization, by default 1.0
""" """
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
...@@ -325,24 +334,24 @@ class MLP(nn.Module): ...@@ -325,24 +334,24 @@ class MLP(nn.Module):
class RealFFT2(nn.Module): class RealFFT2(nn.Module):
""" """
Helper routine to wrap FFT similarly to the SHT Helper routine to wrap FFT similarly to the SHT.
This module provides a wrapper around PyTorch's real FFT2D that mimics
the interface of spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
""" """
def __init__(self, nlat, nlon, lmax=None, mmax=None): def __init__(self, nlat, nlon, lmax=None, mmax=None):
"""
Initialize RealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super(RealFFT2, self).__init__() super(RealFFT2, self).__init__()
self.nlat = nlat self.nlat = nlat
...@@ -352,17 +361,17 @@ class RealFFT2(nn.Module): ...@@ -352,17 +361,17 @@ class RealFFT2(nn.Module):
def forward(self, x): def forward(self, x):
""" """
Forward pass of RealFFT2. Forward pass: compute real FFT2D.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon) Input tensor
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor with shape (batch, channels, nlat, mmax) FFT coefficients
""" """
y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho") y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2) y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
...@@ -371,24 +380,24 @@ class RealFFT2(nn.Module): ...@@ -371,24 +380,24 @@ class RealFFT2(nn.Module):
class InverseRealFFT2(nn.Module): class InverseRealFFT2(nn.Module):
""" """
Helper routine to wrap inverse FFT similarly to the SHT Helper routine to wrap inverse FFT similarly to the SHT.
This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
the interface of inverse spherical harmonic transforms for consistency.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum spherical harmonic degree, by default None (same as nlat)
mmax : int, optional
Maximum spherical harmonic order, by default None (nlon//2 + 1)
""" """
def __init__(self, nlat, nlon, lmax=None, mmax=None): def __init__(self, nlat, nlon, lmax=None, mmax=None):
"""
Initialize InverseRealFFT2 module.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
lmax : int, optional
Maximum l mode, by default None (same as nlat)
mmax : int, optional
Maximum m mode, by default None (nlon // 2 + 1)
"""
super(InverseRealFFT2, self).__init__() super(InverseRealFFT2, self).__init__()
self.nlat = nlat self.nlat = nlat
...@@ -398,45 +407,46 @@ class InverseRealFFT2(nn.Module): ...@@ -398,45 +407,46 @@ class InverseRealFFT2(nn.Module):
def forward(self, x): def forward(self, x):
""" """
Forward pass of InverseRealFFT2. Forward pass: compute inverse real FFT2D.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor with shape (batch, channels, nlat, mmax) Input FFT coefficients
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor with shape (batch, channels, nlat, nlon) Reconstructed spatial signal
""" """
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho") return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
""" """
Wrapper class that moves the channel dimension to the end Wrapper class that moves the channel dimension to the end.
This module provides a layer normalization that works with channel-first
tensors by temporarily transposing the channel dimension to the end,
applying normalization, and then transposing back.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type for the module, by default None
""" """
def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None): def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
"""
Initialize LayerNorm module.
Parameters
-----------
in_channels : int
Number of input channels
eps : float, optional
Epsilon for numerical stability, by default 1e-05
elementwise_affine : bool, optional
Whether to use learnable affine parameters, by default True
bias : bool, optional
Whether to use bias, by default True
device : torch.device, optional
Device to place the module on, by default None
dtype : torch.dtype, optional
Data type, by default None
"""
super().__init__() super().__init__()
self.channel_dim = -3 self.channel_dim = -3
...@@ -445,31 +455,33 @@ class LayerNorm(nn.Module): ...@@ -445,31 +455,33 @@ class LayerNorm(nn.Module):
def forward(self, x): def forward(self, x):
""" """
Forward pass of LayerNorm. Forward pass with channel dimension handling.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor Input tensor with channel dimension at -3
Returns Returns
------- -------
torch.Tensor torch.Tensor
Normalized tensor Normalized tensor with same shape as input
""" """
return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim) return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
class SpectralConvS2(nn.Module): class SpectralConvS2(nn.Module):
""" """
Spectral convolution layer for spherical data. Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters Parameters
----------- -----------
forward_transform : nn.Module forward_transform : nn.Module
Forward transform (e.g., RealSHT) Forward transform (SHT or FFT)
inverse_transform : nn.Module inverse_transform : nn.Module
Inverse transform (e.g., InverseRealSHT) Inverse transform (ISHT or IFFT)
in_channels : int in_channels : int
Number of input channels Number of input channels
out_channels : int out_channels : int
...@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module): ...@@ -477,31 +489,49 @@ class SpectralConvS2(nn.Module):
gain : float, optional gain : float, optional
Gain factor for weight initialization, by default 2.0 Gain factor for weight initialization, by default 2.0
operator_type : str, optional operator_type : str, optional
Type of spectral operator, by default "driscoll-healy" Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
lr_scale_exponent : int, optional lr_scale_exponent : int, optional
Learning rate scale exponent, by default 0 Learning rate scaling exponent, by default 0
bias : bool, optional bias : bool, optional
Whether to use bias, by default False Whether to use bias, by default False
""" """
def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False): def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
super(SpectralConvS2, self).__init__() super().__init__()
self.forward_transform = forward_transform self.forward_transform = forward_transform
self.inverse_transform = inverse_transform self.inverse_transform = inverse_transform
self.in_channels = in_channels
self.out_channels = out_channels self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.operator_type = operator_type self.operator_type = operator_type
self.lr_scale_exponent = lr_scale_exponent
# initialize the weights assert self.inverse_transform.lmax == self.modes_lat
scale = math.sqrt(gain / in_channels) assert self.inverse_transform.mmax == self.modes_lon
self.weight = nn.Parameter(scale * torch.randn(out_channels, in_channels, dtype=torch.cfloat))
if bias: weight_shape = [out_channels, in_channels]
self.bias = nn.Parameter(torch.zeros(out_channels, dtype=torch.cfloat))
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
self.contract_func = "...ilm,oilm->...olm"
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
self.contract_func = "...ilm,oilnm->...oln"
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
self.contract_func = "...ilm,oil->...olm"
else: else:
self.bias = None raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
scale = math.sqrt(gain / in_channels)
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x): def forward(self, x):
""" """
...@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module): ...@@ -514,28 +544,36 @@ class SpectralConvS2(nn.Module):
Returns Returns
------- -------
torch.Tensor tuple
Output tensor after spectral convolution Tuple containing (output, residual) tensors
""" """
# apply forward transform dtype = x.dtype
x = self.forward_transform(x) x = x.float()
residual = x
# apply spectral convolution with torch.autocast(device_type="cuda", enabled=False):
x = torch.einsum("bilm,oim->bolm", x, self.weight) x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
# apply inverse transform x = torch.einsum(self.contract_func, x, self.weight)
x = self.inverse_transform(x)
# add bias if present with torch.autocast(device_type="cuda", enabled=False):
if self.bias is not None: x = self.inverse_transform(x)
x = x + self.bias.view(1, -1, 1, 1)
return x if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
return x, residual
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
""" """
Abstract base class for position embeddings on spherical data. Abstract base class for position embeddings.
This class defines the interface for position embedding modules
that add positional information to input tensors.
Parameters Parameters
----------- -----------
...@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): ...@@ -548,30 +586,34 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(PositionEmbedding, self).__init__() super().__init__()
self.img_shape = img_shape self.img_shape = img_shape
self.grid = grid
self.num_chans = num_chans self.num_chans = num_chans
@abc.abstractmethod
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
""" """
Abstract forward method for position embedding. Forward pass: add position embeddings to input.
Parameters Parameters
----------- -----------
x : torch.Tensor x : torch.Tensor
Input tensor Input tensor
Returns
-------
torch.Tensor
Input tensor with position embeddings added
""" """
pass return x + self.position_embeddings
class SequencePositionEmbedding(PositionEmbedding): class SequencePositionEmbedding(PositionEmbedding):
""" """
Sequence-based position embedding for spherical data. Standard sequence-based position embedding.
This module adds position embeddings based on the sequence of latitude and longitude This module implements sinusoidal position embeddings similar to those
coordinates, providing spatial context to the model. used in the original Transformer paper, adapted for 2D spatial data.
Parameters Parameters
----------- -----------
...@@ -582,38 +624,29 @@ class SequencePositionEmbedding(PositionEmbedding): ...@@ -582,38 +624,29 @@ class SequencePositionEmbedding(PositionEmbedding):
num_chans : int, optional num_chans : int, optional
Number of channels, by default 1 Number of channels, by default 1
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(SequencePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# create position embeddings with torch.no_grad():
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1]) # alternating custom position embeddings
nn.init.trunc_normal_(pos_embed, std=0.02) pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
self.register_buffer("pos_embed", pos_embed) k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
denom = torch.pow(10000, 2 * k / self.num_chans)
def forward(self, x: torch.Tensor): pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))
"""
Forward pass of sequence position embedding. # register tensor
self.register_buffer("position_embeddings", pos_embed.float())
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with position embeddings added
"""
return x + self.pos_embed
class SpectralPositionEmbedding(PositionEmbedding): class SpectralPositionEmbedding(PositionEmbedding):
r""" """
Spectral position embedding for spherical data. Spectral position embeddings for spherical transformers.
This module adds position embeddings in the spectral domain using spherical harmonics, This module creates position embeddings in the spectral domain using
providing spectral context to the model. spherical harmonic functions, which are particularly suitable for
spherical data processing.
Parameters Parameters
----------- -----------
...@@ -624,39 +657,43 @@ class SpectralPositionEmbedding(PositionEmbedding): ...@@ -624,39 +657,43 @@ class SpectralPositionEmbedding(PositionEmbedding):
num_chans : int, optional num_chans : int, optional
Number of channels, by default 1 Number of channels, by default 1
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
super(SpectralPositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
# create spectral position embeddings # compute maximum required frequency and prepare isht
pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1] // 2 + 1, dtype=torch.cfloat) lmax = math.floor(math.sqrt(self.num_chans)) + 1
nn.init.trunc_normal_(pos_embed.real, std=0.02) isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)
nn.init.trunc_normal_(pos_embed.imag, std=0.02)
self.register_buffer("pos_embed", pos_embed)
def forward(self, x: torch.Tensor): # fill position embedding
""" with torch.no_grad():
Forward pass of spectral position embedding. pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)
Parameters for i in range(self.num_chans):
----------- l = math.floor(math.sqrt(i))
x : torch.Tensor m = i - l**2 - l
Input tensor
if m < 0:
Returns pos_embed_freq[0, i, l, -m] = 1.0j
------- else:
torch.Tensor pos_embed_freq[0, i, l, m] = 1.0
Tensor with spectral position embeddings added
""" # compute spatial position embeddings
return x + self.pos_embed pos_embed = isht(pos_embed_freq)
# normalization
pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)
# register tensor
self.register_buffer("position_embeddings", pos_embed)
class LearnablePositionEmbedding(PositionEmbedding): class LearnablePositionEmbedding(PositionEmbedding):
r""" """
Learnable position embedding for spherical data. Learnable position embeddings for spherical transformers.
This module adds learnable position embeddings that are optimized during training, This module provides learnable position embeddings that can be either
allowing the model to learn optimal spatial representations. latitude-only or full latitude-longitude embeddings.
Parameters Parameters
----------- -----------
...@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding): ...@@ -667,47 +704,18 @@ class LearnablePositionEmbedding(PositionEmbedding):
num_chans : int, optional num_chans : int, optional
Number of channels, by default 1 Number of channels, by default 1
embed_type : str, optional embed_type : str, optional
Embedding type ("lat", "lon", or "both"), by default "lat" Embedding type ("lat" or "latlon"), by default "lat"
""" """
def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"): def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
super(LearnablePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
self.embed_type = embed_type
if embed_type == "lat":
# latitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], 1))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "lon":
# longitude embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, 1, img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
elif embed_type == "latlon":
# full lat-lon embedding
pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], img_shape[1]))
nn.init.trunc_normal_(pos_embed, std=0.02)
self.register_parameter("pos_embed", pos_embed)
else:
raise ValueError(f"Unknown embedding type {embed_type}")
def forward(self, x: torch.Tensor): if embed_type == "latlon":
""" self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
Forward pass of learnable position embedding. elif embed_type == "lat":
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
Parameters else:
----------- raise ValueError(f"Unknown learnable position embedding type {embed_type}")
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Tensor with learnable position embeddings added
"""
return x + self.pos_embed
# class SpiralPositionEmbedding(PositionEmbedding): # class SpiralPositionEmbedding(PositionEmbedding):
# """ # """
...@@ -731,4 +739,4 @@ class LearnablePositionEmbedding(PositionEmbedding): ...@@ -731,4 +739,4 @@ class LearnablePositionEmbedding(PositionEmbedding):
# pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats))) # pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
# # register tensor # # register tensor
# self.register_buffer("position_embeddings", pos_embed.float()) # self.register_buffer("position_embeddings", pos_embed.float())
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment