# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
import math
import warnings

import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint

from torch_harmonics import InverseRealSHT


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """
    Internal function that fills a tensor in-place with values drawn from a truncated normal distribution.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to fill with values
    mean : float
        Mean of the normal distribution
    std : float
        Standard deviation of the normal distribution
    a : float
        Lower bound for truncation
    b : float
        Upper bound for truncation

    Returns
    -------
    torch.Tensor
        The filled tensor
    """

    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fills the input Tensor with values drawn from a truncated normal distribution.
    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they are within the bounds. The method used
    for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks;
    however, the original name is misleading, as 'Drop Connect' is a different form of
    dropout in a separate paper. See the discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    I've opted for changing the layer and argument names to 'drop path' rather than mixing
    DropConnect as a layer name with 'survival rate' as the argument.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor
    drop_prob : float, optional
        Drop path probability, by default 0.0
    training : bool, optional
        Whether in training mode, by default False

    Returns
    -------
    torch.Tensor
        Output tensor with drop path applied
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with tensors of arbitrary rank, not just 2d ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=0.0):
        """
        Initialize DropPath module.

        Parameters
        ----------
        drop_prob : float, optional
            Drop path probability, by default 0.0
        """
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """
        Forward pass with drop path.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor with drop path applied
        """
        return drop_path(x, self.drop_prob, self.training)
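# Usage sketch (illustrative, not part of the original module): during training,
# DropPath zeroes out entire samples with probability drop_prob and rescales the
# surviving ones by 1 / (1 - drop_prob); at inference time it is the identity.
#
#   >>> dp = DropPath(drop_prob=0.2).train()
#   >>> x = torch.ones(8, 16, 32, 32)
#   >>> y = dp(x)  # each sample is either all zeros or all 1 / 0.8 = 1.25
#   >>> _ = dp.eval()
#   >>> torch.equal(dp(x), x)
#   True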
class PatchEmbed(nn.Module):
    """
    Patch embedding layer for vision transformers.

    Parameters
    ----------
    img_size : tuple, optional
        Input image size (height, width), by default (224, 224)
    patch_size : tuple, optional
        Patch size (height, width), by default (16, 16)
    in_chans : int, optional
        Number of input channels, by default 3
    embed_dim : int, optional
        Embedding dimension, by default 768
    """

    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super(PatchEmbed, self).__init__()

        self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
        num_patches = self.red_img_size[0] * self.red_img_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
        self.proj.weight.is_shared_mp = ["spatial"]
        self.proj.bias.is_shared_mp = ["spatial"]

    def forward(self, x):
        """
        Forward pass of patch embedding.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, height, width)

        Returns
        -------
        torch.Tensor
            Embedded patches with shape (batch, embed_dim, num_patches)
        """
        # gather input
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # new shape: B, embed_dim, num_patches
        x = self.proj(x).flatten(2)
        return x
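# Usage sketch (illustrative): a (224, 224) image cut into (16, 16) patches yields
# 14 * 14 = 196 patch tokens, each embedded into embed_dim channels.
#
#   >>> pe = PatchEmbed(img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768)
#   >>> x = torch.randn(2, 3, 224, 224)
#   >>> pe(x).shape
#   torch.Size([2, 768, 196])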
class MLP(nn.Module):
    """
    Multi-layer perceptron with optional checkpointing.

    Parameters
    ----------
    in_features : int
        Number of input features
    hidden_features : int, optional
        Number of hidden features, by default None (same as in_features)
    out_features : int, optional
        Number of output features, by default None (same as in_features)
    act_layer : nn.Module, optional
        Activation layer, by default nn.ReLU
    output_bias : bool, optional
        Whether to use bias in the output layer, by default False
    drop_rate : float, optional
        Dropout rate, by default 0.0
    checkpointing : bool, optional
        Whether to use gradient checkpointing, by default False
    gain : float, optional
        Gain factor for output initialization, by default 1.0
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
                 output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
        super(MLP, self).__init__()
        self.checkpointing = checkpointing
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # first dense layer
        fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)

        # initialize the weights correctly
        scale = math.sqrt(2.0 / in_features)
        nn.init.normal_(fc1.weight, mean=0.0, std=scale)
        if fc1.bias is not None:
            nn.init.constant_(fc1.bias, 0.0)

        # activation
        act = act_layer()

        # output layer
        fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)

        # gain factor for the output determines the scaling of the output init
        scale = math.sqrt(gain / hidden_features)
        nn.init.normal_(fc2.weight, mean=0.0, std=scale)
        if fc2.bias is not None:
            nn.init.constant_(fc2.bias, 0.0)

        if drop_rate > 0.0:
            # the same dropout module is reused after the activation and the output layer
            drop = nn.Dropout2d(drop_rate)
            self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
        else:
            self.fwd = nn.Sequential(fc1, act, fc2)

    @torch.jit.ignore
    def checkpoint_forward(self, x):
        """
        Forward pass with gradient checkpointing.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor
        """
        return checkpoint(self.fwd, x)

    def forward(self, x):
        """
        Forward pass of the MLP.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor
        """
        if self.checkpointing:
            return self.checkpoint_forward(x)
        else:
            return self.fwd(x)
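# Usage sketch (illustrative): because both layers are 1x1 convolutions, the MLP
# acts pointwise on (batch, channels, height, width) feature maps.
#
#   >>> mlp = MLP(in_features=64, hidden_features=128, out_features=64)
#   >>> x = torch.randn(2, 64, 30, 60)
#   >>> mlp(x).shape
#   torch.Size([2, 64, 30, 60])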
class RealFFT2(nn.Module):
    """
    Helper routine to wrap the FFT similarly to the SHT
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        """
        Initialize RealFFT2 module.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        lmax : int, optional
            Maximum l mode, by default None (same as nlat)
        mmax : int, optional
            Maximum m mode, by default None (nlon // 2 + 1)
        """
        super(RealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        """
        Forward pass of RealFFT2.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch, channels, lmax, mmax)
        """
        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
        # truncate to the lowest lmax modes along the latitude dimension:
        # the first rows hold the non-negative frequencies, the last rows the negative ones
        y = torch.cat(
            (
                y[..., : math.ceil(self.lmax / 2), : self.mmax],
                y[..., -math.floor(self.lmax / 2) :, : self.mmax],
            ),
            dim=-2,
        )
        return y


class InverseRealFFT2(nn.Module):
    """
    Helper routine to wrap the inverse FFT similarly to the SHT
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        """
        Initialize InverseRealFFT2 module.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        lmax : int, optional
            Maximum l mode, by default None (same as nlat)
        mmax : int, optional
            Maximum m mode, by default None (nlon // 2 + 1)
        """
        super(InverseRealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        """
        Forward pass of InverseRealFFT2.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, lmax, mmax)

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch, channels, nlat, nlon)
        """
        return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
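# Usage sketch (illustrative): with lmax and mmax left at their defaults, the two
# wrappers are exact inverses of each other up to floating-point error.
#
#   >>> fft = RealFFT2(nlat=32, nlon=64)
#   >>> ifft = InverseRealFFT2(nlat=32, nlon=64)
#   >>> x = torch.randn(2, 4, 32, 64)
#   >>> fft(x).shape
#   torch.Size([2, 4, 32, 33])
#   >>> torch.allclose(ifft(fft(x)), x, atol=1e-5)
#   True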
class LayerNorm(nn.Module):
    """
    Wrapper class that moves the channel dimension to the end so that nn.LayerNorm normalizes over channels
    """

    def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
        """
        Initialize LayerNorm module.

        Parameters
        ----------
        in_channels : int
            Number of input channels
        eps : float, optional
            Epsilon for numerical stability, by default 1e-05
        elementwise_affine : bool, optional
            Whether to use learnable affine parameters, by default True
        bias : bool, optional
            Whether to use bias, by default True
        device : torch.device, optional
            Device to place the module on, by default None
        dtype : torch.dtype, optional
            Data type, by default None
        """
        super().__init__()

        self.channel_dim = -3
        self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        """
        Forward pass of LayerNorm.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Normalized tensor
        """
        return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)


class SpectralConvS2(nn.Module):
    """
    Spectral convolution layer for spherical data.

    Parameters
    ----------
    forward_transform : nn.Module
        Forward transform (e.g., RealSHT)
    inverse_transform : nn.Module
        Inverse transform (e.g., InverseRealSHT)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    gain : float, optional
        Gain factor for weight initialization, by default 2.0
    operator_type : str, optional
        Type of spectral operator, by default "driscoll-healy"
    lr_scale_exponent : int, optional
        Learning rate scale exponent, by default 0
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(self, forward_transform, inverse_transform, in_channels, out_channels,
                 gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
        super(SpectralConvS2, self).__init__()

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.operator_type = operator_type
        self.lr_scale_exponent = lr_scale_exponent

        # initialize the weights; a Driscoll-Healy style operator couples channels
        # with one complex weight per degree l, shared across all orders m
        scale = math.sqrt(gain / in_channels)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, in_channels, self.forward_transform.lmax, dtype=torch.cfloat))

        if bias:
            # the bias is added in the spatial domain after the inverse transform, so it is real-valued
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    def forward(self, x):
        """
        Forward pass of spectral convolution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after spectral convolution
        """
        # apply forward transform
        x = self.forward_transform(x)

        # apply spectral convolution: contract channels with one weight per degree l
        x = torch.einsum("bilm,oil->bolm", x, self.weight)

        # apply inverse transform
        x = self.inverse_transform(x)

        # add bias if present
        if self.bias is not None:
            x = x + self.bias.view(1, -1, 1, 1)

        return x
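# Usage sketch (illustrative; assumes the RealSHT / InverseRealSHT pair from
# torch_harmonics with their default lmax / mmax settings):
#
#   >>> from torch_harmonics import RealSHT
#   >>> nlat, nlon = 32, 64
#   >>> sht = RealSHT(nlat, nlon, grid="equiangular")
#   >>> isht = InverseRealSHT(nlat, nlon, grid="equiangular")
#   >>> conv = SpectralConvS2(sht, isht, in_channels=8, out_channels=16)
#   >>> x = torch.randn(2, 8, nlat, nlon)
#   >>> conv(x).shape
#   torch.Size([2, 16, 32, 64])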
class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for position embeddings on spherical data.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super(PositionEmbedding, self).__init__()

        self.img_shape = img_shape
        self.grid = grid
        self.num_chans = num_chans

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        """
        Abstract forward method for position embedding.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        """
        pass


class SequencePositionEmbedding(PositionEmbedding):
    """
    Sequence-based position embedding for spherical data.

    This module adds a fixed (non-learnable) position embedding defined over the
    full lat-lon grid, initialized from a truncated normal distribution and
    registered as a buffer, providing spatial context to the model.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super(SequencePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # create position embeddings
        pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1])
        nn.init.trunc_normal_(pos_embed, std=0.02)
        self.register_buffer("pos_embed", pos_embed)

    def forward(self, x: torch.Tensor):
        """
        Forward pass of sequence position embedding.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Tensor with position embeddings added
        """
        return x + self.pos_embed


class SpectralPositionEmbedding(PositionEmbedding):
    r"""
    Spectral position embedding for spherical data.

    This module adds a fixed complex-valued position embedding directly to
    spectral coefficients; the input is expected to be in the spectral domain
    with shape (batch, num_chans, img_shape[0], img_shape[1] // 2 + 1).

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super(SpectralPositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # create spectral position embeddings
        pos_embed = torch.zeros(1, num_chans, img_shape[0], img_shape[1] // 2 + 1, dtype=torch.cfloat)
        nn.init.trunc_normal_(pos_embed.real, std=0.02)
        nn.init.trunc_normal_(pos_embed.imag, std=0.02)
        self.register_buffer("pos_embed", pos_embed)

    def forward(self, x: torch.Tensor):
        """
        Forward pass of spectral position embedding.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Tensor with spectral position embeddings added
        """
        return x + self.pos_embed


class LearnablePositionEmbedding(PositionEmbedding):
    r"""
    Learnable position embedding for spherical data.

    This module adds learnable position embeddings that are optimized during
    training, allowing the model to learn optimal spatial representations.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    embed_type : str, optional
        Embedding type ("lat", "lon", or "latlon"), by default "lat"
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
        super(LearnablePositionEmbedding, self).__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        self.embed_type = embed_type

        if embed_type == "lat":
            # latitude embedding, broadcast across longitudes
            pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], 1))
            nn.init.trunc_normal_(pos_embed, std=0.02)
            self.register_parameter("pos_embed", pos_embed)
        elif embed_type == "lon":
            # longitude embedding, broadcast across latitudes
            pos_embed = nn.Parameter(torch.zeros(1, num_chans, 1, img_shape[1]))
            nn.init.trunc_normal_(pos_embed, std=0.02)
            self.register_parameter("pos_embed", pos_embed)
        elif embed_type == "latlon":
            # full lat-lon embedding
            pos_embed = nn.Parameter(torch.zeros(1, num_chans, img_shape[0], img_shape[1]))
            nn.init.trunc_normal_(pos_embed, std=0.02)
            self.register_parameter("pos_embed", pos_embed)
        else:
            raise ValueError(f"Unknown embedding type {embed_type}")

    def forward(self, x: torch.Tensor):
        """
        Forward pass of learnable position embedding.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Tensor with learnable position embeddings added
        """
        return x + self.pos_embed


# class SpiralPositionEmbedding(PositionEmbedding):
#     """
#     Returns position embeddings on the torus
#     """
#
#     def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
#         super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)
#
#         with torch.no_grad():
#             # alternating custom position embeddings
#             lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
#             lats = lats.reshape(-1, 1)
#             lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
#             lons = lons.reshape(1, -1)
#
#             # channel index
#             k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
#             pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))
#
#             # register tensor
#             self.register_buffer("position_embeddings", pos_embed.float())
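# Usage sketch for the position embeddings above (illustrative): a "lat" embedding
# learns one value per latitude ring per channel and broadcasts it across all longitudes.
#
#   >>> emb = LearnablePositionEmbedding(img_shape=(32, 64), num_chans=8, embed_type="lat")
#   >>> emb.pos_embed.shape
#   torch.Size([1, 8, 32, 1])
#   >>> x = torch.randn(2, 8, 32, 64)
#   >>> emb(x).shape
#   torch.Size([2, 8, 32, 64])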