# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import List, Tuple, Union, Optional
import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics._disco_convolution import (
    _disco_s2_contraction_torch,
    _disco_s2_transpose_contraction_torch,
    _disco_s2_contraction_triton,
    _disco_s2_transpose_contraction_triton,
)


def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
    """
    Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
    """

    # compute the support
    dtheta = (theta_cutoff - 0.0) / ntheta
    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
    itheta = ikernel * dtheta

    norm_factor = 2 * math.pi * (
        1
        - math.cos(theta_cutoff - dtheta)
        + math.cos(theta_cutoff - dtheta)
        + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta
    )

    # find the indices where the rotated position falls into the support of the kernel
    iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff))
    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor

    return iidx, vals
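
# Illustrative sketch (not part of the library API): probing the support routine above
# on a small (lat, lon) grid of rotated coordinates. The grid sizes, kernel size, and
# cutoff below are arbitrary assumptions chosen purely for demonstration.
def _example_isotropic_support():
    nlat, nlon, ntheta = 6, 12, 3
    # colatitudes in [0, pi] and longitudes in [0, 2 pi), broadcast to a 2D grid
    theta = torch.linspace(0, math.pi, nlat).reshape(-1, 1).expand(nlat, nlon)
    phi = torch.linspace(0, 2 * math.pi, nlon + 1)[:-1].reshape(1, -1).expand(nlat, nlon)
    iidx, vals = _compute_support_vals_isotropic(theta, phi, ntheta=ntheta, theta_cutoff=0.5 * math.pi)
    # each row of iidx is (kernel basis index, latitude index, longitude index),
    # and vals holds the corresponding normalized hat-function value
    assert iidx.shape[1] == 3 and vals.shape[0] == iidx.shape[0]
    return iidx, vals
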
""" # compute the support dtheta = (theta_cutoff - 0.0) / ntheta dphi = 2.0 * math.pi / nphi kernel_size = (ntheta-1)*nphi + 1 ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) itheta = ((ikernel - 1) // nphi + 1) * dtheta iphi = ((ikernel - 1) % nphi) * dphi norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) # find the indices where the rotated position falls into the support of the kernel cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff) cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi) iidx = torch.argwhere(cond_theta & cond_phi) vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor vals *= torch.where(iidx[:, 0] > 0, (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2*math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()) ) / dphi ), 1.0) return iidx, vals def _precompute_convolution_tensor( in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi ): """ Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al. The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in). The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields $$ Y(\alpha) Z(\beta) Y(\gamma) n = {\begin{bmatrix} \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\ \sin(\beta)\sin(\gamma) \\ \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma) \end{bmatrix}} $$ """ assert len(in_shape) == 2 assert len(out_shape) == 2 if len(kernel_shape) == 1: kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff) elif len(kernel_shape) == 2: kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff) else: raise ValueError("kernel_shape should be either one- or two-dimensional.") nlat_in, nlon_in = in_shape nlat_out, nlon_out = out_shape lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in) lats_in = torch.from_numpy(lats_in).float() lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out) lats_out = torch.from_numpy(lats_out).float() # array for accumulating non-zero indices out_idx = torch.empty([3, 0], dtype=torch.long) out_vals = torch.empty([0], dtype=torch.long) # compute the phi differences # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0 lons_in = torch.linspace(0, 2*math.pi, nlon_in+1)[:-1] for t in range(nlat_out): # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis alpha = - lats_out[t] beta = lons_in gamma = lats_in.reshape(-1, 1) # compute cartesian coordinates of the rotated position # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation, # and therefore applied with a negative sign z = - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma) x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha) y = torch.sin(beta) * torch.sin(gamma) # normalization is 

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in * nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO data structure
        out_idx = torch.cat([out_idx, idx], dim=-1)
        out_vals = torch.cat([out_vals, vals], dim=-1)

    return out_idx, out_vals


# TODO:
# - derive conv and conv transpose from a single module


class DiscreteContinuousConvS2(nn.Module):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions,
        ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__()

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        if isinstance(kernel_shape, int):
            kernel_shape = [kernel_shape]

        if len(kernel_shape) == 1:
            self.kernel_size = kernel_shape[0]
        elif len(kernel_shape) == 2:
            self.kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        idx, vals = _precompute_convolution_tensor(
            in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
        )

        psi = torch.sparse_coo_tensor(
            idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
        ).coalesce()
        self.register_buffer("psi", psi, persistent=False)

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups
        scale = math.sqrt(1.0 / self.groupsize)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None
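
    # the forward pass consists of two contractions: the precomputed sparse tensor psi
    # contracts the quadrature-weighted input against all rotated filter basis functions,
    # and an einsum then contracts the learned weights over channels and kernel basis
    # within each group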
    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        if x.is_cuda and use_triton_kernel:
            x = _disco_s2_contraction_triton(x, self.psi, self.nlon_out)
        else:
            x = _disco_s2_contraction_torch(x, self.psi, self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
        out = torch.einsum(
            "bgckxy,gock->bgoxy",
            x,
            self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
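
# Illustrative usage sketch (not part of the library): constructing the layer and applying
# it to random data. All resolutions and channel counts below are arbitrary assumptions
# for demonstration; on CPU tensors the pure PyTorch contraction path is taken automatically.
def _example_disco_conv():
    conv = DiscreteContinuousConvS2(
        in_channels=4, out_channels=8, in_shape=(16, 32), out_shape=(8, 16), kernel_shape=[3]
    )
    x = torch.randn(2, 4, 16, 32)  # batch x channels x nlat_in x nlon_in
    y = conv(x)
    assert y.shape == (2, 8, 8, 16)  # batch x channels x nlat_out x nlon_out
    return y
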
class DiscreteContinuousConvTransposeS2(nn.Module):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions,
        ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__()

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        if isinstance(kernel_shape, int):
            kernel_shape = [kernel_shape]

        if len(kernel_shape) == 1:
            self.kernel_size = kernel_shape[0]
        elif len(kernel_shape) == 2:
            self.kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # integration weights
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals = _precompute_convolution_tensor(
            out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
        )

        psi = torch.sparse_coo_tensor(
            idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
        ).coalesce()
        self.register_buffer("psi", psi, persistent=False)

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups
        scale = math.sqrt(1.0 / self.groupsize)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
        x = torch.einsum(
            "bgcxy,gock->bgokxy",
            x,
            self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]),
        )
        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])

        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        if x.is_cuda and use_triton_kernel:
            out = _disco_s2_transpose_contraction_triton(x, self.psi, self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_torch(x, self.psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
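
# Minimal round-trip sketch (illustrative only, not part of the library): chain a DISCO
# convolution with a transpose DISCO convolution and check the resulting shapes. All
# resolutions and channel counts are arbitrary assumptions chosen for demonstration.
if __name__ == "__main__":
    nlat, nlon = 16, 32
    down = DiscreteContinuousConvS2(3, 6, in_shape=(nlat, nlon), out_shape=(nlat // 2, nlon // 2), kernel_shape=[2])
    up = DiscreteContinuousConvTransposeS2(6, 3, in_shape=(nlat // 2, nlon // 2), out_shape=(nlat, nlon), kernel_shape=[2])

    signal = torch.randn(1, 3, nlat, nlon)
    restored = up(down(signal))
    print(restored.shape)  # expected: torch.Size([1, 3, 16, 32])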