Unverified Commit 54502a17 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Bbonev/disco refactor (#29)

* Cleaned up DISCO convolutions
parent c971d458
......@@ -4,10 +4,11 @@
### v0.6.5
* Discrrete-continuous (DISCO) convolutions on the sphere
* Isotropic and anisotropic DISCO convolutions
* Accelerated DISCO convolutions on GPU via Triton implementation
* Unittests for DISCO convolutions
* Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions
* DISCO supports isotropic and anisotropic kernel functions parameterized as hat functions
* Supports regular and transpose convolutions
* Accelerated spherical DISCO convolutions on GPU via Triton implementation
* Unittests for DISCO convolutions in `tests/test_convolution.py`
### v0.6.4
......
......@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = '0.6.4'
__version__ = '0.6.5'
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
......@@ -29,6 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import abc
from typing import List, Tuple, Union, Optional
import math
......@@ -38,7 +39,7 @@ import torch.nn as nn
from functools import partial
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import (
_disco_s2_contraction_torch,
_disco_s2_transpose_contraction_torch,
......@@ -47,50 +48,67 @@ from torch_harmonics._disco_convolution import (
)
def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# compute the support
dtheta = (theta_cutoff - 0.0) / ntheta
ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
itheta = ikernel * dtheta
norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
dr = (r_cutoff - 0.0) / nr
ikernel = torch.arange(nr).reshape(-1, 1, 1)
ir = ikernel * dr
if norm == "none":
norm_factor = 1.0
elif norm == "2d":
norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
elif norm == "s2":
norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)
else:
raise ValueError(f"Unknown normalization mode {norm}.")
# find the indices where the rotated position falls into the support of the kernel
iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff))
vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
return iidx, vals
def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float, norm: str = "s2"):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# compute the support
dtheta = (theta_cutoff - 0.0) / ntheta
dr = (r_cutoff - 0.0) / nr
dphi = 2.0 * math.pi / nphi
kernel_size = (ntheta-1)*nphi + 1
kernel_size = (nr - 1) * nphi + 1
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
itheta = ((ikernel - 1) // nphi + 1) * dtheta
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
if norm == "none":
norm_factor = 1.0
elif norm == "2d":
norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
elif norm == "s2":
norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)
else:
raise ValueError(f"Unknown normalization mode {norm}.")
# find the indices where the rotated position falls into the support of the kernel
cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi)
iidx = torch.argwhere(cond_theta & cond_phi)
vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
vals *= torch.where(iidx[:, 0] > 0, (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2*math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()) ) / dphi ), 1.0)
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
iidx = torch.argwhere(cond_r & cond_phi)
vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
vals *= torch.where(
iidx[:, 0] > 0,
(1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi),
1.0,
)
return iidx, vals
def _precompute_convolution_tensor(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
):
def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
......@@ -111,9 +129,9 @@ def _precompute_convolution_tensor(
assert len(out_shape) == 2
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2")
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2")
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
......@@ -131,24 +149,24 @@ def _precompute_convolution_tensor(
# compute the phi differences
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2*math.pi, nlon_in+1)[:-1]
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
for t in range(nlat_out):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha = - lats_out[t]
alpha = -lats_out[t]
beta = lons_in
gamma = lats_in.reshape(-1, 1)
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
z = - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
# normalization is emportant to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm = torch.sqrt(x*x + y*y + z*z)
norm = torch.sqrt(x * x + y * y + z * z)
x = x / norm
y = y / norm
z = z / norm
......@@ -170,9 +188,96 @@ def _precompute_convolution_tensor(
return out_idx, out_vals
# TODO:
# - derive conv and conv transpose from single module
class DiscreteContinuousConvS2(nn.Module):
def _precompute_convolution_tensor_2d(grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False):
"""
Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
only that it assumes a non-periodic subset of the euclidean plane
"""
# check that input arrays are valid point clouds in 2D
assert len(grid_in) == 2
assert len(grid_out) == 2
assert grid_in.shape[0] == 2
assert grid_out.shape[0] == 2
n_in = grid_in.shape[-1]
n_out = grid_out.shape[-1]
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=radius_cutoff, norm="2d")
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=radius_cutoff, norm="2d")
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
grid_in = grid_in.reshape(2, 1, n_in)
grid_out = grid_out.reshape(2, n_out, 1)
diffs = grid_in - grid_out
if periodic:
periodic_diffs = torch.where(diffs > 0.0, diffs-1, diffs+1)
diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)
r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi
idx, vals = kernel_handle(r, phi)
idx = idx.permute(1, 0)
return idx, vals
class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for DISCO convolutions
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_shape: Union[int, List[int]],
groups: Optional[int] = 1,
bias: Optional[bool] = True,
):
super().__init__()
if isinstance(kernel_shape, int):
self.kernel_shape = [kernel_shape]
else:
self.kernel_shape = kernel_shape
if len(self.kernel_shape) == 1:
self.kernel_size = self.kernel_shape[0]
elif len(self.kernel_shape) == 2:
self.kernel_size = (self.kernel_shape[0] - 1) * self.kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
# groups
self.groups = groups
# weight tensor
if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
else:
self.bias = None
@abc.abstractmethod
def forward(self, x: torch.Tensor):
raise NotImplementedError
class DiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
......@@ -192,24 +297,14 @@ class DiscreteContinuousConvS2(nn.Module):
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__()
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
self.kernel_size = kernel_shape[0]
elif len(kernel_shape) == 2:
self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
# compute theta cutoff based on the bandlimit of the input field
if theta_cutoff is None:
theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
......@@ -219,38 +314,20 @@ class DiscreteContinuousConvS2(nn.Module):
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
self.register_buffer("quad_weights", quad_weights, persistent=False)
idx, vals = _precompute_convolution_tensor(
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
# ).coalesce()
idx, vals = _precompute_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff)
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)
# groups
self.groups = groups
# weight tensor
if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
else:
self.bias = None
def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
return psi
def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
psi = self.get_psi()
if x.is_cuda and use_triton_kernel:
x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
......@@ -271,7 +348,7 @@ class DiscreteContinuousConvS2(nn.Module):
return out
class DiscreteContinuousConvTransposeS2(nn.Module):
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
......@@ -291,23 +368,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__()
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape
if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
self.kernel_size = kernel_shape[0]
elif len(kernel_shape) == 2:
self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
# bandlimit
if theta_cutoff is None:
theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
......@@ -318,32 +386,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
self.register_buffer("quad_weights", quad_weights, persistent=False)
# switch in_shape and out_shape since we want transpose conv
idx, vals = _precompute_convolution_tensor(
out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
# ).coalesce()
idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff)
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)
# groups
self.groups = groups
# weight tensor
if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
else:
self.bias = None
def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
return psi
def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
# extract shape
......@@ -357,7 +407,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
psi = self.get_psi()
if x.is_cuda and use_triton_kernel:
out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
......@@ -368,3 +418,4 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
out = out + self.bias.reshape(1, -1, 1, 1)
return out
......@@ -31,26 +31,53 @@
import numpy as np
def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
if (grid != "equidistant") and periodic:
raise ValueError(f"Periodic grid is only supported on equidistant grids.")
# compute coordinates
if grid == "equidistant":
xlg, wlg = trapezoidal_weights(n, a=a, b=b, periodic=periodic)
elif grid == "legendre-gauss":
xlg, wlg = legendre_gauss_weights(n, a=a, b=b)
elif grid == "lobatto":
xlg, wlg = lobatto_weights(n, a=a, b=b)
elif grid == "equiangular":
xlg, wlg = clenshaw_curtiss_weights(n, a=a, b=b)
else:
raise ValueError(f"Unknown grid type {grid}")
return xlg, wlg
def _precompute_latitudes(nlat, grid="equiangular"):
r"""
Convenience routine to precompute latitudes
"""
# compute coordinates
if grid == "legendre-gauss":
xlg, wlg = legendre_gauss_weights(nlat)
elif grid == "lobatto":
xlg, wlg = lobatto_weights(nlat)
elif grid == "equiangular":
xlg, wlg = clenshaw_curtiss_weights(nlat)
else:
raise ValueError("Unknown grid")
xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
lats = np.flip(np.arccos(xlg)).copy()
wlg = np.flip(wlg).copy()
return lats, wlg
def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
r"""
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
"""
xlg = np.linspace(a, b, n)
wlg = (b - a) / (n - 1) * np.ones(n)
if not periodic:
wlg[0] *= 0.5
wlg[-1] *= 0.5
return xlg, wlg
def legendre_gauss_weights(n, a=-1.0, b=1.0):
r"""
Helper routine which returns the Legendre-Gauss nodes and weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment