# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
import math
import warnings

import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint

from torch_harmonics import InverseRealSHT


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fills the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]`
    redrawn until they are within the bounds. The method used for generating the
    random values works best when :math:`a \leq \text{mean} \leq b`.

    Parameters
    ----------
    tensor : torch.Tensor
        an n-dimensional `torch.Tensor`
    mean : float
        the mean of the normal distribution
    std : float
        the standard deviation of the normal distribution
    a : float
        the minimum cutoff value, by default -2.0
    b : float
        the maximum cutoff value, by default 2.0

    Examples
    --------
    >>> w = torch.empty(3, 5)
    >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
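
# A minimal usage sketch for the helper above (hypothetical shapes, kept as a comment so
# the module stays import-safe; the helper mirrors nn.init.trunc_normal_):
#
#   w = torch.empty(768, 256)
#   trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)
#   # fills w in place with N(0, 0.02^2) samples truncated to [a, b] = [-2, 2]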


@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc. networks,
    however, the original name is misleading as 'Drop Connect' is a different form of
    dropout in a separate paper... See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    I've opted for changing the layer and argument names to 'drop path' rather than mix
    DropConnect as a layer name and use 'survival rate' as the argument.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor
    drop_prob : float, optional
        Probability of dropping a path, by default 0.0
    training : bool, optional
        Whether the model is in training mode, by default False

    Returns
    -------
    torch.Tensor
        Output tensor
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2d ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This module implements stochastic depth regularization by randomly dropping entire
    residual paths during training, which helps with regularization and training of very
    deep networks.

    Parameters
    ----------
    drop_prob : float, optional
        Probability of dropping a path, by default None
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    Patch embedding layer for vision transformers.

    This module splits input images into patches and projects them to a higher
    dimensional embedding space using convolutional layers.

    Parameters
    ----------
    img_size : tuple, optional
        Input image size (height, width), by default (224, 224)
    patch_size : tuple, optional
        Patch size (height, width), by default (16, 16)
    in_chans : int, optional
        Number of input channels, by default 3
    embed_dim : int, optional
        Embedding dimension, by default 768
    """

    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super(PatchEmbed, self).__init__()
        self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
        num_patches = self.red_img_size[0] * self.red_img_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
        self.proj.weight.is_shared_mp = ["spatial"]
        self.proj.bias.is_shared_mp = ["spatial"]

    def forward(self, x):
        # gather input
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # new: B, C, H*W
        x = self.proj(x).flatten(2)
        return x
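
# A minimal usage sketch (hypothetical sizes): PatchEmbed returns flattened patch tokens
# with shape (B, embed_dim, num_patches), not the (B, num_patches, embed_dim) layout used
# by some ViT implementations.
#
#   embed = PatchEmbed(img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768)
#   tokens = embed(torch.randn(2, 3, 224, 224))
#   # tokens.shape == (2, 768, 196), since (224 // 16) * (224 // 16) == 196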


class MLP(nn.Module):
    """
    Multi-layer perceptron with optional checkpointing.

    This module implements a feed-forward network with two linear layers and an
    activation function, with optional dropout and gradient checkpointing.

    Parameters
    ----------
    in_features : int
        Number of input features
    hidden_features : int, optional
        Number of hidden features, by default None (same as in_features)
    out_features : int, optional
        Number of output features, by default None (same as in_features)
    act_layer : nn.Module, optional
        Activation layer, by default nn.ReLU
    output_bias : bool, optional
        Whether to use bias in output layer, by default False
    drop_rate : float, optional
        Dropout rate, by default 0.0
    checkpointing : bool, optional
        Whether to use gradient checkpointing, by default False
    gain : float, optional
        Gain factor for weight initialization, by default 1.0
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
        super(MLP, self).__init__()
        self.checkpointing = checkpointing
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # first dense layer
        fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
        # initialize the weights correctly
        scale = math.sqrt(2.0 / in_features)
        nn.init.normal_(fc1.weight, mean=0.0, std=scale)
        if fc1.bias is not None:
            nn.init.constant_(fc1.bias, 0.0)

        # activation
        act = act_layer()

        # output layer
        fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
        # gain factor for the output determines the scaling of the output init
        scale = math.sqrt(gain / hidden_features)
        nn.init.normal_(fc2.weight, mean=0.0, std=scale)
        if fc2.bias is not None:
            nn.init.constant_(fc2.bias, 0.0)

        if drop_rate > 0.0:
            drop = nn.Dropout2d(drop_rate)
            self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
        else:
            self.fwd = nn.Sequential(fc1, act, fc2)

    @torch.jit.ignore
    def checkpoint_forward(self, x):
        return checkpoint(self.fwd, x)

    def forward(self, x):
        if self.checkpointing:
            return self.checkpoint_forward(x)
        else:
            return self.fwd(x)


class RealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT.

    This module provides a wrapper around PyTorch's real FFT2D that mimics the
    interface of spherical harmonic transforms for consistency.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    lmax : int, optional
        Maximum spherical harmonic degree, by default None (same as nlat)
    mmax : int, optional
        Maximum spherical harmonic order, by default None (nlon // 2 + 1)
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(RealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
        y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
        return y


class InverseRealFFT2(nn.Module):
    """
    Helper routine to wrap inverse FFT similarly to the SHT.

    This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
    the interface of inverse spherical harmonic transforms for consistency.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    lmax : int, optional
        Maximum spherical harmonic degree, by default None (same as nlat)
    mmax : int, optional
        Maximum spherical harmonic order, by default None (nlon // 2 + 1)
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(InverseRealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
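
# A minimal round-trip sketch for the FFT wrappers (hypothetical grid size). With the
# default lmax/mmax all modes are kept, so the inverse recovers the input up to
# floating-point error:
#
#   fft = RealFFT2(nlat=64, nlon=128)
#   ifft = InverseRealFFT2(nlat=64, nlon=128)
#   x = torch.randn(2, 4, 64, 128)
#   coeffs = fft(x)        # complex tensor of shape (2, 4, 64, 65)
#   x_rec = ifft(coeffs)   # back to (2, 4, 64, 128)
#   # torch.allclose(x, x_rec, atol=1e-5) should hold when no modes are truncated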


class LayerNorm(nn.Module):
    """
    Wrapper class that moves the channel dimension to the end.

    This module provides a layer normalization that works with channel-first tensors
    by temporarily transposing the channel dimension to the end, applying normalization,
    and then transposing back.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    eps : float, optional
        Epsilon for numerical stability, by default 1e-05
    elementwise_affine : bool, optional
        Whether to use learnable affine parameters, by default True
    bias : bool, optional
        Whether to use bias, by default True
    device : torch.device, optional
        Device to place the module on, by default None
    dtype : torch.dtype, optional
        Data type for the module, by default None
    """

    def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
        super().__init__()

        self.channel_dim = -3
        # pass the requested eps through instead of hard-coding it
        self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)


class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy.

    Designed for convolutions on the two-sphere S2 using the Spherical Harmonic
    Transforms in torch-harmonics, but supports convolutions on the periodic domain
    via the RealFFT2 and InverseRealFFT2 wrappers.

    Parameters
    ----------
    forward_transform : nn.Module
        Forward transform (SHT or FFT)
    inverse_transform : nn.Module
        Inverse transform (ISHT or IFFT)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    gain : float, optional
        Gain factor for weight initialization, by default 2.0
    operator_type : str, optional
        Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
    lr_scale_exponent : int, optional
        Learning rate scaling exponent, by default 0
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
        super().__init__()

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)

        # remember factorization details
        self.operator_type = operator_type

        assert self.inverse_transform.lmax == self.modes_lat
        assert self.inverse_transform.mmax == self.modes_lon

        weight_shape = [out_channels, in_channels]

        if self.operator_type == "diagonal":
            weight_shape += [self.modes_lat, self.modes_lon]
            self.contract_func = "...ilm,oilm->...olm"
        elif self.operator_type == "block-diagonal":
            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
            self.contract_func = "...ilm,oilnm->...oln"
        elif self.operator_type == "driscoll-healy":
            weight_shape += [self.modes_lat]
            self.contract_func = "...ilm,oil->...olm"
        else:
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")

        # form weight tensors
        scale = math.sqrt(gain / in_channels)
        self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))

        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):
        dtype = x.dtype
        x = x.float()
        residual = x

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = torch.einsum(self.contract_func, x, self.weight)

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias

        x = x.type(dtype)

        return x, residual
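
# A minimal usage sketch (hypothetical channel counts and grid size; RealSHT and
# InverseRealSHT are the spherical transforms shipped with torch-harmonics):
#
#   from torch_harmonics import RealSHT
#
#   sht = RealSHT(nlat=64, nlon=128, grid="equiangular").float()
#   isht = InverseRealSHT(nlat=64, nlon=128, grid="equiangular").float()
#   conv = SpectralConvS2(sht, isht, in_channels=32, out_channels=32, operator_type="driscoll-healy")
#   y, residual = conv(torch.randn(2, 32, 64, 128))
#   # y.shape == residual.shape == (2, 32, 64, 128)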


class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for position embeddings.

    This class defines the interface for position embedding modules that add
    positional information to input tensors.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__()

        self.img_shape = img_shape
        self.num_chans = num_chans

    def forward(self, x: torch.Tensor):
        return x + self.position_embeddings


class SequencePositionEmbedding(PositionEmbedding):
    """
    Standard sequence-based position embedding.

    This module implements sinusoidal position embeddings similar to those used in
    the original Transformer paper, adapted for 2D spatial data.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        with torch.no_grad():
            # alternating custom position embeddings
            pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
            k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
            denom = torch.pow(10000, 2 * k / self.num_chans)
            pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))

        # register tensor
        self.register_buffer("position_embeddings", pos_embed.float())
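
# A minimal usage sketch (hypothetical shapes): all PositionEmbedding subclasses simply
# add a (1, num_chans, H, W)-broadcastable buffer or parameter to the input.
#
#   pe = SequencePositionEmbedding(img_shape=(60, 120), num_chans=8)
#   x = torch.randn(2, 8, 60, 120)
#   y = pe(x)   # same shape as x, with sinusoidal offsets added per channel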


class SpectralPositionEmbedding(PositionEmbedding):
    """
    Spectral position embeddings for spherical transformers.

    This module creates position embeddings in the spectral domain using spherical
    harmonic functions, which are particularly suitable for spherical data processing.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # compute maximum required frequency and prepare isht
        lmax = math.floor(math.sqrt(self.num_chans)) + 1
        isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)

        # fill position embedding
        with torch.no_grad():
            pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)
            for i in range(self.num_chans):
                l = math.floor(math.sqrt(i))
                m = i - l**2 - l
                if m < 0:
                    pos_embed_freq[0, i, l, -m] = 1.0j
                else:
                    pos_embed_freq[0, i, l, m] = 1.0

            # compute spatial position embeddings
            pos_embed = isht(pos_embed_freq)

            # normalization
            pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)

        # register tensor
        self.register_buffer("position_embeddings", pos_embed)


class LearnablePositionEmbedding(PositionEmbedding):
    """
    Learnable position embeddings for spherical transformers.

    This module provides learnable position embeddings that can be either
    latitude-only or full latitude-longitude embeddings.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    embed_type : str, optional
        Embedding type ("lat" or "latlon"), by default "lat"
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        if embed_type == "latlon":
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
        elif embed_type == "lat":
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
        else:
            raise ValueError(f"Unknown learnable position embedding type {embed_type}")
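
# A minimal sketch of the two learnable variants (hypothetical sizes): "lat" learns one
# value per latitude ring and broadcasts over longitude, while "latlon" learns a full map.
#
#   pe_lat = LearnablePositionEmbedding(img_shape=(60, 120), num_chans=8, embed_type="lat")
#   pe_latlon = LearnablePositionEmbedding(img_shape=(60, 120), num_chans=8, embed_type="latlon")
#   # pe_lat.position_embeddings.shape == (1, 8, 60, 1)
#   # pe_latlon.position_embeddings.shape == (1, 8, 60, 120)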


# class SpiralPositionEmbedding(PositionEmbedding):
#     """
#     Returns position embeddings on the torus
#     """

#     def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
#         super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

#         with torch.no_grad():
#             # alternating custom position embeddings
#             lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
#             lats = lats.reshape(-1, 1)
#             lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
#             lons = lons.reshape(1, -1)

#             # channel index
#             k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
#             pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))

#         # register tensor
#         self.register_buffer("position_embeddings", pos_embed.float())
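
# A rough, hypothetical sketch of how the blocks above can be composed into a spectral
# residual block (illustration only, not the library's actual block implementation):
#
#   from torch_harmonics import RealSHT
#
#   nlat, nlon, chans = 64, 128, 32
#   sht = RealSHT(nlat, nlon, grid="equiangular").float()
#   isht = InverseRealSHT(nlat, nlon, grid="equiangular").float()
#
#   norm = LayerNorm(chans)
#   conv = SpectralConvS2(sht, isht, chans, chans)
#   mlp = MLP(chans, hidden_features=2 * chans, output_bias=True)
#   drop = DropPath(0.1)
#
#   x = torch.randn(2, chans, nlat, nlon)
#   y, residual = conv(norm(x))
#   x = residual + drop(mlp(y))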