Unverified Commit 54502a17 authored by Boris Bonev, committed by GitHub

Bbonev/disco refactor (#29)

* Cleaned up DISCO convolutions
parent c971d458
@@ -4,10 +4,11 @@
 ### v0.6.5

-* Discrrete-continuous (DISCO) convolutions on the sphere
-* Isotropic and anisotropic DISCO convolutions
-* Accelerated DISCO convolutions on GPU via Triton implementation
-* Unittests for DISCO convolutions
+* Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions
+* DISCO supports isotropic and anisotropic kernel functions parameterized as hat functions
+* Supports regular and transpose convolutions
+* Accelerated spherical DISCO convolutions on GPU via Triton implementation
+* Unittests for DISCO convolutions in `tests/test_convolution.py`

 ### v0.6.4
...
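For context, a minimal usage sketch of the spherical DISCO convolution advertised in this changelog entry. The class and constructor arguments appear in the hunks below, but their positional order is not fully visible in this diff, so keyword arguments are used; shapes are illustrative only.

```python
import torch
from torch_harmonics import DiscreteContinuousConvS2

# anisotropic kernel with 3 radial rings and 4 angular basis functions
conv = DiscreteContinuousConvS2(
    in_channels=4,
    out_channels=8,
    in_shape=(32, 64),   # (nlat, nlon) of the input grid
    out_shape=(16, 32),  # (nlat, nlon) of the output grid
    kernel_shape=[3, 4],
    grid_in="equiangular",
    grid_out="equiangular",
)

x = torch.randn(2, 4, 32, 64)  # (batch, channels, nlat, nlon)
y = conv(x)                    # expected shape: (2, 8, 16, 32)
```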
@@ -29,7 +29,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #

-__version__ = '0.6.4'
+__version__ = '0.6.5'

 from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
...
@@ -29,6 +29,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #

+import abc
 from typing import List, Tuple, Union, Optional
 import math
@@ -38,7 +39,7 @@ import torch.nn as nn
 from functools import partial

-from torch_harmonics.quadrature import _precompute_latitudes
+from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
 from torch_harmonics._disco_convolution import (
     _disco_s2_contraction_torch,
     _disco_s2_transpose_contraction_torch,
@@ -47,50 +48,67 @@ from torch_harmonics._disco_convolution import (
 )


-def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
+def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"):
     """
     Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
     """

     # compute the support
-    dtheta = (theta_cutoff - 0.0) / ntheta
-    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
-    itheta = ikernel * dtheta
+    dr = (r_cutoff - 0.0) / nr
+    ikernel = torch.arange(nr).reshape(-1, 1, 1)
+    ir = ikernel * dr

-    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
+    if norm == "none":
+        norm_factor = 1.0
+    elif norm == "2d":
+        norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
+    elif norm == "s2":
+        norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)
+    else:
+        raise ValueError(f"Unknown normalization mode {norm}.")

     # find the indices where the rotated position falls into the support of the kernel
-    iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff))
-    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
+    iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
+    vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
     return iidx, vals
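Each radial basis function above is a linear hat of width `dr`: it equals 1 at its center `i * dr` and falls off linearly to 0 one hat-width away. A self-contained sketch of the same support test on mock inputs (shapes chosen arbitrarily for illustration):

```python
import torch

nr, r_cutoff = 3, 1.0
dr = r_cutoff / nr
r = r_cutoff * torch.rand(8, 8)               # mock radial distances, shape (H, W)
ikernel = torch.arange(nr).reshape(-1, 1, 1)  # broadcasts to (nr, H, W) against r
ir = ikernel * dr

# nonzero only within one hat-width of a center and inside the cutoff
iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
# vals is 1 at a hat center and decays linearly to 0 at distance dr
```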

-def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
+def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float, norm: str = "s2"):
     """
     Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
     """

     # compute the support
-    dtheta = (theta_cutoff - 0.0) / ntheta
+    dr = (r_cutoff - 0.0) / nr
     dphi = 2.0 * math.pi / nphi
-    kernel_size = (ntheta-1)*nphi + 1
+    kernel_size = (nr - 1) * nphi + 1
     ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
-    itheta = ((ikernel - 1) // nphi + 1) * dtheta
+    ir = ((ikernel - 1) // nphi + 1) * dr
     iphi = ((ikernel - 1) % nphi) * dphi

-    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
+    if norm == "none":
+        norm_factor = 1.0
+    elif norm == "2d":
+        norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
+    elif norm == "s2":
+        norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)
+    else:
+        raise ValueError(f"Unknown normalization mode {norm}.")

     # find the indices where the rotated position falls into the support of the kernel
-    cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
-    cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi)
-    iidx = torch.argwhere(cond_theta & cond_phi)
-    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
-    vals *= torch.where(iidx[:, 0] > 0, (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2*math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()) ) / dphi ), 1.0)
+    cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
+    cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
+    iidx = torch.argwhere(cond_r & cond_phi)
+    vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
+    vals *= torch.where(
+        iidx[:, 0] > 0,
+        (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi),
        1.0,
+    )
     return iidx, vals
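Note the flattened kernel indexing: index 0 is the single center hat with no angular dependence (hence the `ikernel == 0` escape in `cond_phi`), and the remaining indices enumerate the `(nr - 1) * nphi` ring/angle combinations, giving `kernel_size = (nr - 1) * nphi + 1`. A quick enumeration of the basis-function centers for a small example:

```python
import math
import torch

nr, nphi, r_cutoff = 3, 4, 1.0
dr, dphi = r_cutoff / nr, 2 * math.pi / nphi
kernel_size = (nr - 1) * nphi + 1          # (3 - 1) * 4 + 1 = 9
ikernel = torch.arange(kernel_size)
ir = ((ikernel - 1) // nphi + 1) * dr      # radial centers: 0, then ring 1 (x4), ring 2 (x4)
iphi = ((ikernel - 1) % nphi) * dphi       # angular centers; unused for index 0
print(list(zip(ir.tolist(), iphi.tolist())))
```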

-def _precompute_convolution_tensor(
-    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
-):
+def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
     Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
@@ -111,9 +129,9 @@ def _precompute_convolution_tensor(
     assert len(out_shape) == 2

     if len(kernel_shape) == 1:
-        kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
+        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2")
     elif len(kernel_shape) == 2:
-        kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
+        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2")
     else:
         raise ValueError("kernel_shape should be either one- or two-dimensional.")
@@ -131,24 +149,24 @@ def _precompute_convolution_tensor(
     # compute the phi differences
     # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
-    lons_in = torch.linspace(0, 2*math.pi, nlon_in+1)[:-1]
+    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

     for t in range(nlat_out):
         # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
-        alpha = - lats_out[t]
+        alpha = -lats_out[t]
         beta = lons_in
         gamma = lats_in.reshape(-1, 1)

         # compute cartesian coordinates of the rotated position
         # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
         # and therefore applied with a negative sign
-        z = - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
+        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
         x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
         y = torch.sin(beta) * torch.sin(gamma)

         # normalization is important to avoid NaNs when arccos and atan are applied
         # this can otherwise lead to spurious artifacts in the solution
-        norm = torch.sqrt(x*x + y*y + z*z)
+        norm = torch.sqrt(x * x + y * y + z * z)
         x = x / norm
         y = y / norm
         z = z / norm
@@ -170,9 +188,96 @@ def _precompute_convolution_tensor(
     return out_idx, out_vals

-# TODO:
-# - derive conv and conv transpose from single module
-class DiscreteContinuousConvS2(nn.Module):
+def _precompute_convolution_tensor_2d(grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False):
+    """
+    Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
+    only that it assumes a non-periodic subset of the Euclidean plane.
+    """
+
+    # check that input arrays are valid point clouds in 2D
+    assert len(grid_in) == 2
+    assert len(grid_out) == 2
+    assert grid_in.shape[0] == 2
+    assert grid_out.shape[0] == 2
+
+    n_in = grid_in.shape[-1]
+    n_out = grid_out.shape[-1]
+
+    if len(kernel_shape) == 1:
+        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=radius_cutoff, norm="2d")
+    elif len(kernel_shape) == 2:
+        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=radius_cutoff, norm="2d")
+    else:
+        raise ValueError("kernel_shape should be either one- or two-dimensional.")
+
+    grid_in = grid_in.reshape(2, 1, n_in)
+    grid_out = grid_out.reshape(2, n_out, 1)
+
+    diffs = grid_in - grid_out
+    if periodic:
+        periodic_diffs = torch.where(diffs > 0.0, diffs - 1, diffs + 1)
+        diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)
+
+    r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
+    phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi
+
+    idx, vals = kernel_handle(r, phi)
+    idx = idx.permute(1, 0)
+
+    return idx, vals
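A sketch of how the returned index/value pair is typically assembled into the sparse filter tensor, mirroring the `get_psi` methods further down. The point clouds here are random stand-ins, and the sketch assumes the module-internal function above is in scope:

```python
import torch

grid_in = torch.rand(2, 64)    # 64 input points in [0, 1]^2, stacked as (x, y)
grid_out = torch.rand(2, 16)   # 16 output points
kernel_shape = [2, 4]

idx, vals = _precompute_convolution_tensor_2d(grid_in, grid_out, kernel_shape, radius_cutoff=0.2)

kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
psi = torch.sparse_coo_tensor(idx, vals, size=(kernel_size, 16, 64)).coalesce()
```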
+
+class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
+    """
+    Abstract base class for DISCO convolutions
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_shape: Union[int, List[int]],
+        groups: Optional[int] = 1,
+        bias: Optional[bool] = True,
+    ):
+        super().__init__()
+
+        if isinstance(kernel_shape, int):
+            self.kernel_shape = [kernel_shape]
+        else:
+            self.kernel_shape = kernel_shape
+
+        if len(self.kernel_shape) == 1:
+            self.kernel_size = self.kernel_shape[0]
+        elif len(self.kernel_shape) == 2:
+            self.kernel_size = (self.kernel_shape[0] - 1) * self.kernel_shape[1] + 1
+        else:
+            raise ValueError("kernel_shape should be either one- or two-dimensional.")
+
+        # groups
+        self.groups = groups
+
+        # weight tensor
+        if in_channels % self.groups != 0:
+            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
+        if out_channels % self.groups != 0:
+            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
+        self.groupsize = in_channels // self.groups
+        scale = math.sqrt(1.0 / self.groupsize)
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels))
+        else:
+            self.bias = None
+
+    @abc.abstractmethod
+    def forward(self, x: torch.Tensor):
+        raise NotImplementedError
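The base class therefore owns a single weight tensor of shape `(out_channels, in_channels // groups, kernel_size)`, scaled like a Kaiming initialization so the output variance is roughly independent of the group size. A worked example of the bookkeeping:

```python
import math
import torch

in_channels, out_channels, groups, kernel_size = 8, 8, 4, 9
groupsize = in_channels // groups               # 2 input channels per group
scale = math.sqrt(1.0 / groupsize)
weight = scale * torch.randn(out_channels, groupsize, kernel_size)
print(weight.shape)                             # torch.Size([8, 2, 9])
```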
+
+class DiscreteContinuousConvS2(DiscreteContinuousConv):
     """
     Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
@@ -192,24 +297,14 @@ class DiscreteContinuousConvS2(nn.Module):
         bias: Optional[bool] = True,
         theta_cutoff: Optional[float] = None,
     ):
-        super().__init__()
+        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

         self.nlat_in, self.nlon_in = in_shape
         self.nlat_out, self.nlon_out = out_shape

-        if isinstance(kernel_shape, int):
-            kernel_shape = [kernel_shape]
-
-        if len(kernel_shape) == 1:
-            self.kernel_size = kernel_shape[0]
-        elif len(kernel_shape) == 2:
-            self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
-        else:
-            raise ValueError("kernel_shape should be either one- or two-dimensional.")
-
         # compute theta cutoff based on the bandlimit of the input field
         if theta_cutoff is None:
-            theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -219,38 +314,20 @@ class DiscreteContinuousConvS2(nn.Module):
         quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
         self.register_buffer("quad_weights", quad_weights, persistent=False)

-        idx, vals = _precompute_convolution_tensor(
-            in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
-        )
-
-        # psi = torch.sparse_coo_tensor(
-        #     idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
-        # ).coalesce()
+        idx, vals = _precompute_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff)

         self.register_buffer("psi_idx", idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)
-        # self.register_buffer("psi", psi, persistent=False)
-
-        # groups
-        self.groups = groups
-
-        # weight tensor
-        if in_channels % self.groups != 0:
-            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
-        if out_channels % self.groups != 0:
-            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
-        self.groupsize = in_channels // self.groups
-        scale = math.sqrt(1.0 / self.groupsize)
-        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
-
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(out_channels))
-        else:
-            self.bias = None
+
+    def get_psi(self):
+        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
+        return psi

     def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
         # pre-multiply x with the quadrature weights
         x = self.quad_weights * x

-        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
+        psi = self.get_psi()

         if x.is_cuda and use_triton_kernel:
             x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
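Storing `psi_idx`/`psi_vals` as buffers and rebuilding the sparse tensor in `get_psi()` sidesteps registering a sparse tensor as a buffer directly (the old code had this commented out). For intuition only, a hypothetical dense equivalent of the contraction for the longitude-zero output column; the real kernels reuse `psi` for all output longitudes by exploiting the rotational structure of the grid:

```python
import torch

K, nlat_out, nlat_in, nlon_in = 9, 16, 32, 64
psi_dense = torch.rand(K, nlat_out, nlat_in * nlon_in)  # stand-in for psi.to_dense()
x = torch.rand(2, 4, nlat_in, nlon_in)                  # already quadrature-weighted
y = torch.einsum("kop,bcp->bcko", psi_dense, x.reshape(2, 4, -1))
# y: (batch, channels, kernel_size, nlat_out), to be combined with the weight tensor
```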
@@ -271,7 +348,7 @@ class DiscreteContinuousConvS2(nn.Module):
         return out


-class DiscreteContinuousConvTransposeS2(nn.Module):
+class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
     """
     Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
@@ -291,23 +368,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
         bias: Optional[bool] = True,
         theta_cutoff: Optional[float] = None,
     ):
-        super().__init__()
+        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

         self.nlat_in, self.nlon_in = in_shape
         self.nlat_out, self.nlon_out = out_shape

-        if isinstance(kernel_shape, int):
-            kernel_shape = [kernel_shape]
-
-        if len(kernel_shape) == 1:
-            self.kernel_size = kernel_shape[0]
-        elif len(kernel_shape) == 2:
-            self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
-        else:
-            raise ValueError("kernel_shape should be either one- or two-dimensional.")
-
         # bandlimit
         if theta_cutoff is None:
-            theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)

         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -318,32 +386,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
         self.register_buffer("quad_weights", quad_weights, persistent=False)

         # switch in_shape and out_shape since we want transpose conv
-        idx, vals = _precompute_convolution_tensor(
-            out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
-        )
-
-        # psi = torch.sparse_coo_tensor(
-        #     idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
-        # ).coalesce()
+        idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff)

         self.register_buffer("psi_idx", idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)
-        # self.register_buffer("psi", psi, persistent=False)
-
-        # groups
-        self.groups = groups
-
-        # weight tensor
-        if in_channels % self.groups != 0:
-            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
-        if out_channels % self.groups != 0:
-            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
-        self.groupsize = in_channels // self.groups
-        scale = math.sqrt(1.0 / self.groupsize)
-        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
-
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(out_channels))
-        else:
-            self.bias = None
+
+    def get_psi(self):
+        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
+        return psi

     def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
         # extract shape
@@ -357,7 +407,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
         # pre-multiply x with the quadrature weights
         x = self.quad_weights * x

-        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
+        psi = self.get_psi()

         if x.is_cuda and use_triton_kernel:
             out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
@@ -368,3 +418,4 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
             out = out + self.bias.reshape(1, -1, 1, 1)

         return out
@@ -31,26 +31,53 @@
 import numpy as np


+def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
+
+    if (grid != "equidistant") and periodic:
+        raise ValueError(f"Periodic grid is only supported on equidistant grids.")
+
+    # compute coordinates
+    if grid == "equidistant":
+        xlg, wlg = trapezoidal_weights(n, a=a, b=b, periodic=periodic)
+    elif grid == "legendre-gauss":
+        xlg, wlg = legendre_gauss_weights(n, a=a, b=b)
+    elif grid == "lobatto":
+        xlg, wlg = lobatto_weights(n, a=a, b=b)
+    elif grid == "equiangular":
+        xlg, wlg = clenshaw_curtiss_weights(n, a=a, b=b)
+    else:
+        raise ValueError(f"Unknown grid type {grid}")
+
+    return xlg, wlg
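A quick check of the new helper as defined above (with it in scope): equidistant nodes on [0, 1], whose trapezoidal weights sum to the interval length.

```python
xlg, wlg = _precompute_grid(5, grid="equidistant", a=0.0, b=1.0)
print(xlg)        # [0.   0.25 0.5  0.75 1.  ]
print(wlg.sum())  # 1.0
```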
+
 def _precompute_latitudes(nlat, grid="equiangular"):
     r"""
     Convenience routine to precompute latitudes
     """

     # compute coordinates
-    if grid == "legendre-gauss":
-        xlg, wlg = legendre_gauss_weights(nlat)
-    elif grid == "lobatto":
-        xlg, wlg = lobatto_weights(nlat)
-    elif grid == "equiangular":
-        xlg, wlg = clenshaw_curtiss_weights(nlat)
-    else:
-        raise ValueError("Unknown grid")
+    xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)

     lats = np.flip(np.arccos(xlg)).copy()
     wlg = np.flip(wlg).copy()

     return lats, wlg
+
+def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
+    r"""
+    Helper routine which returns equidistant nodes with trapezoidal weights
+    on the interval [a, b]
+    """
+
+    xlg = np.linspace(a, b, n)
+    wlg = (b - a) / (n - 1) * np.ones(n)
+    if not periodic:
+        wlg[0] *= 0.5
+        wlg[-1] *= 0.5
+
+    return xlg, wlg
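A sanity check of the trapezoidal rule against a known integral, here ∫₀^π sin(θ) dθ = 2:

```python
import numpy as np

x, w = trapezoidal_weights(101, a=0.0, b=np.pi)
print(np.sum(w * np.sin(x)))  # ≈ 2.0, with O(h^2) error
```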

 def legendre_gauss_weights(n, a=-1.0, b=1.0):
     r"""
     Helper routine which returns the Legendre-Gauss nodes and weights
...