Unverified Commit ab44ba59 authored by Thorsten Kurth, committed by GitHub

Tkurth/cleanup (#90)

* removing duplicate code from distributed convolution

* replacing from_numpy with as_tensor

* make preprocess_psi_tensor GPU ready.
parent bd92cdf7
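
A minimal standalone sketch of the from_numpy -> as_tensor replacement described above (not code from the repository): torch.as_tensor accepts NumPy arrays as well as existing tensors, lists, and scalars, and, like torch.from_numpy, it shares memory with a NumPy array when no dtype or device conversion is required.

import numpy as np
import torch

# hypothetical example data, only for illustration
arr = np.linspace(0.0, 1.0, 5)

# torch.from_numpy accepts only NumPy arrays and always shares memory
t1 = torch.from_numpy(arr)

# torch.as_tensor also accepts tensors, lists, and scalars; for a NumPy array
# with a matching dtype and device it shares memory as well, so no copy is made
t2 = torch.as_tensor(arr)

arr[0] = 42.0
assert t1[0].item() == 42.0 and t2[0].item() == 42.0  # both views see the change
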
......@@ -429,7 +429,7 @@ def main(
# print dataset info
img_size = dataset.input_shape[1:]
class_histogram = torch.from_numpy(dataset.class_histogram)
class_histogram = torch.as_tensor(dataset.class_histogram)
# various class weights were tried, such as inverse frequency
# using no class weights seems to work best
......
......@@ -263,7 +263,7 @@ class NeighborhoodAttentionS2(nn.Module):
fb = get_filter_basis(kernel_shape=1, basis_type="zernike")
# precompute the neighborhood sparsity pattern
idx, vals = _precompute_convolution_tensor_s2(
idx, _, roff = _precompute_convolution_tensor_s2(
in_shape,
out_shape,
fb,
......@@ -278,26 +278,13 @@ class NeighborhoodAttentionS2(nn.Module):
# this is kept for legacy reasons in case we want to reuse the sorting of these entries
row_idx = idx[1, ...].contiguous()
col_idx = idx[2, ...].contiguous()
roff_idx = roff.contiguous()
# compute row offsets for more structured traversal.
# this only works if the rows are sorted, which they are by construction
row_offset = np.empty(self.nlat_out + 1, dtype=np.int64)
row_offset[0] = 0
row = row_idx[0]
for idz, z in enumerate(range(col_idx.shape[0])):
if row_idx[z] != row:
row_offset[row + 1] = idz
row = row_idx[z]
# set the last value
row_offset[row + 1] = idz + 1
row_offset = torch.from_numpy(row_offset).contiguous()
# store some metadata
self.max_psi_nnz = col_idx.max().item() + 1
self.register_buffer("psi_row_idx", row_idx, persistent=False)
self.register_buffer("psi_col_idx", col_idx, persistent=False)
self.register_buffer("psi_roff_idx", row_offset, persistent=False)
# self.register_buffer("psi_vals", vals, persistent=False)
self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
# learnable parameters
# TODO: double-check that this gives us the correct initialization magnitudes
......
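
A minimal sketch (not the library's preprocess_psi routine) of the CSR-style row offsets that replace the Python loop removed above: assuming the COO row indices are already sorted, as the comment notes they are by construction, the same offsets can be built with bincount and cumsum; the sizes below are made up for illustration.

import torch

# hypothetical sorted COO row indices for nlat_out = 4 output rows
row_idx = torch.tensor([0, 0, 1, 1, 1, 3], dtype=torch.int64)
nlat_out = 4

# row_offset[r] .. row_offset[r + 1] delimit the nonzeros belonging to output row r
counts = torch.bincount(row_idx, minlength=nlat_out)
row_offset = torch.zeros(nlat_out + 1, dtype=torch.int64)
row_offset[1:] = torch.cumsum(counts, dim=0)

print(row_offset)  # tensor([0, 2, 5, 5, 6]); empty rows get zero-length ranges
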
......@@ -67,6 +67,10 @@ def _normalize_convolution_tensor_s2(
- "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
"""
# exit here if no normalization is needed
if basis_norm_mode == "none":
return psi_vals
# reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
......@@ -192,18 +196,29 @@ def _precompute_convolution_tensor_s2(
out_idx = []
out_vals = []
beta = lons_in
gamma = lats_in.reshape(-1, 1)
# compute trigs
cbeta = torch.cos(beta)
sbeta = torch.sin(beta)
cgamma = torch.cos(gamma)
sgamma = torch.sin(gamma)
# compute row offsets
out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64)
out_roff[0] = 0
for t in range(nlat_out):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha = -lats_out[t]
beta = lons_in
gamma = lats_in.reshape(-1, 1)
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha)
y = sbeta * sgamma
z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma
# normalization is important to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
......@@ -223,9 +238,10 @@ def _precompute_convolution_tensor_s2(
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
# append indices and values to the COO datastructure
# append indices and values to the COO datastructure, compute row offsets
out_idx.append(idx)
out_vals.append(vals)
out_roff[t + 1] = out_roff[t] + iidx.shape[0]
# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1)
......@@ -246,7 +262,7 @@ def _precompute_convolution_tensor_s2(
out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()
return out_idx, out_vals
return out_idx, out_vals, out_roff
class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
......@@ -333,7 +349,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
idx, vals = _precompute_convolution_tensor_s2(
idx, vals, _ = _precompute_convolution_tensor_s2(
in_shape,
out_shape,
self.filter_basis,
......@@ -353,7 +369,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
if _cuda_extension_available:
# preprocessed data-structure for GPU kernel
roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals).contiguous()
self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
# save all datastructures
......@@ -438,7 +454,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
raise ValueError("Error, theta_cutoff has to be positive.")
# switch in_shape and out_shape since we want the transpose convolution
idx, vals = _precompute_convolution_tensor_s2(
idx, vals, _ = _precompute_convolution_tensor_s2(
out_shape,
in_shape,
self.filter_basis,
......@@ -458,7 +474,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
if _cuda_extension_available:
# preprocessed data-structure for GPU kernel
roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous()
self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
# save all datastructures
......
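
A standalone sanity check (using plain linspace grids instead of the library's grid helpers) of the YZY rotation computed in the hunks above: with alpha = 0, i.e. an output point at the north pole, the rotated spherical coordinates must reduce to the input colatitudes and longitudes.

import math
import torch

# plain stand-ins for _precompute_latitudes / _precompute_longitudes
beta = torch.linspace(0.0, 2 * math.pi, 9)[:-1]                # longitudes, 2*pi excluded
gamma = torch.linspace(0.1, math.pi - 0.1, 6).reshape(-1, 1)   # interior colatitudes
alpha = torch.tensor(0.0)                                      # output point at the north pole

# cartesian coordinates of the rotated position (YZY convention, passive last angle)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)

# normalize before arccos / arctan2, as the comments above recommend
norm = torch.sqrt(x * x + y * y + z * z)
theta = torch.arccos(z / norm)
phi = torch.arctan2(y / norm, x / norm)
phi = torch.where(phi < 0.0, phi + 2 * math.pi, phi)  # map phi into [0, 2*pi)

# with no rotation the output coordinates coincide with the input grid
assert torch.allclose(theta, gamma.expand_as(theta), atol=1e-5)
assert torch.allclose(phi, beta.expand_as(phi), atol=1e-5)
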
......@@ -46,7 +46,7 @@ from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_t
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis
from torch_harmonics.convolution import (
_normalize_convolution_tensor_s2,
_precompute_convolution_tensor_s2,
DiscreteContinuousConv,
)
......@@ -68,115 +68,15 @@ except ImportError as err:
_cuda_extension_available = False
def _precompute_distributed_convolution_tensor_s2(
in_shape,
out_shape,
filter_basis,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False,
basis_norm_mode="mean",
merge_quadrature=False,
def _split_distributed_convolution_tensor_s2(
idx: torch.Tensor,
vals: torch.Tensor,
in_shape: Tuple[int],
out_shape: Tuple[int],
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
The rotation of the Euler angles uses the YZY convention, which, applied to the north pole $(0,0,1)^T$, yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = filter_basis.kernel_size
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = _precompute_longitudes(nlon_in)
# compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere.
if transpose_normalization:
quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
# the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
out_idx = []
out_vals = []
for t in range(nlat_out):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha = -lats_out[t]
beta = lons_in
gamma = lats_in.reshape(-1, 1)
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
# normalization is important to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm = torch.sqrt(x * x + y * y + z * z)
x = x / norm
y = y / norm
z = z / norm
# compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
theta = torch.arccos(z)
phi = torch.arctan2(y, x)
phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
# append indices and values to the COO datastructure
out_idx.append(idx)
out_vals.append(vals)
# concatenate the indices and values
out_idx = torch.cat(out_idx, dim=-1)
out_vals = torch.cat(out_vals, dim=-1)
out_vals = _normalize_convolution_tensor_s2(
out_idx,
out_vals,
in_shape,
out_shape,
kernel_size,
quad_weights,
transpose_normalization=transpose_normalization,
basis_norm_mode=basis_norm_mode,
merge_quadrature=merge_quadrature,
)
# TODO: this part can be split off into its own function
# split the latitude indices:
comm_size_polar = polar_group_size()
comm_rank_polar = polar_group_rank()
split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
......@@ -185,17 +85,18 @@ def _precompute_distributed_convolution_tensor_s2(
end_idx = offsets[comm_rank_polar + 1]
# once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
lats = out_idx[2] // nlon_in
lons = out_idx[2] % nlon_in
lats = idx[2] // nlon_in
lons = idx[2] % nlon_in
ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze()
out_vals = out_vals[ilats]
vals = vals[ilats]
# for the indices, we need to recompute them to refer to the local indices of the input tensor
out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)
idx = torch.stack([idx[0, ilats], idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)
out_idx = out_idx.contiguous()
out_vals = out_vals.to(dtype=torch.float32).contiguous()
# make results contiguous
idx = idx.contiguous()
vals = vals.to(dtype=torch.float32).contiguous()
return out_idx, out_vals
return idx, vals
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
......@@ -254,7 +155,8 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
self.nlat_out_local = self.nlat_out
idx, vals = _precompute_distributed_convolution_tensor_s2(
# compute global convolution tensor
idx, vals, _ = _precompute_convolution_tensor_s2(
in_shape,
out_shape,
self.filter_basis,
......@@ -266,6 +168,9 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
merge_quadrature=True,
)
# split the convolution tensor along latitude
idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, in_shape, out_shape)
# sort the values
ker_idx = idx[0, ...].contiguous()
row_idx = idx[1, ...].contiguous()
......@@ -343,6 +248,8 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
We assume the data can be split in the polar and azimuthal directions.
"""
def __init__(
......@@ -392,9 +299,10 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
self.nlat_in_local = self.nlat_in
self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]
# compute global convolution tensor
# switch in_shape and out_shape since we want transpose conv
# distributed mode here is swapped because of the transpose
idx, vals = _precompute_distributed_convolution_tensor_s2(
idx, vals, _ = _precompute_convolution_tensor_s2(
out_shape,
in_shape,
self.filter_basis,
......@@ -406,6 +314,10 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
merge_quadrature=True,
)
# split the convolution tensor along latitude; again, we need to swap the meaning
# of in_shape and out_shape
idx, vals = _split_distributed_convolution_tensor_s2(idx, vals, out_shape, in_shape)
# sort the values
ker_idx = idx[0, ...].contiguous()
row_idx = idx[1, ...].contiguous()
......
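
A hedged sketch of the latitude split performed by the new _split_distributed_convolution_tensor_s2 helper, with made-up COO data and hard-coded offsets standing in for the polar communicator queries: only nonzeros whose input latitude falls into the local [start_idx, end_idx) range are kept, and their flattened input index is shifted to local numbering.

import torch

# hypothetical COO convolution tensor: rows of idx are (kernel, output row, in_lat * nlon_in + in_lon)
nlon_in = 8
idx = torch.tensor([[0, 0, 1, 1],
                    [0, 1, 2, 3],
                    [5, 13, 21, 29]], dtype=torch.int64)
vals = torch.tensor([0.1, 0.2, 0.3, 0.4])

# pretend this rank owns input latitudes [1, 3)
start_idx, end_idx = 1, 3

lats = idx[2] // nlon_in   # global input latitude of each nonzero
lons = idx[2] % nlon_in
keep = torch.argwhere((lats >= start_idx) & (lats < end_idx)).squeeze()

vals = vals[keep]
# re-index the flattened input dimension so that it refers to local latitudes
idx = torch.stack([idx[0, keep], idx[1, keep], (lats[keep] - start_idx) * nlon_in + lons[keep]], dim=0)

print(idx)   # tensor([[ 0,  1], [ 1,  2], [ 5, 13]])
print(vals)  # tensor([0.2000, 0.3000])
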
......@@ -80,7 +80,8 @@ class FilterBasis(metaclass=abc.ABCMeta):
@abc.abstractmethod
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the kernel's support and returns both indices and values. This routine is designed for sparse evaluations of the filter basis
Computes the index set that falls into the kernel's support and returns both indices and values.
This routine is designed for sparse evaluations of the filter basis.
"""
raise NotImplementedError
......@@ -128,7 +129,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
# collocation points
nr = self.kernel_shape[0]
......@@ -148,11 +149,11 @@ class PiecewiseLinearFilterBasis(FilterBasis):
def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
# collocation points
nr = self.kernel_shape[0]
......@@ -179,12 +180,8 @@ class PiecewiseLinearFilterBasis(FilterBasis):
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
# compute the value of the basis functions
vals = 1 - dist_r / dr
vals *= torch.where(
(iidx[:, 0] > 0),
(1 - dist_phi / dphi),
1.0,
)
vals = 1 - dist_r / dr
vals *= torch.where((iidx[:, 0] > 0), (1 - dist_phi / dphi), 1.0)
else:
# in the even case, the inner basis functions overlap into areas with negative radius
......@@ -197,6 +194,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
cond_phin = _circle_dist(phin, iphi) <= dphi
# find indices where conditions are met
iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = _circle_dist(phi[iidx[:, 1], iidx[:, 2]], iphi[iidx[:, 0], 0, 0])
dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
......@@ -251,7 +249,7 @@ class MorletFilterBasis(FilterBasis):
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
nkernel = ikernel % self.kernel_shape[1]
mkernel = ikernel // self.kernel_shape[1]
......@@ -270,7 +268,6 @@ class MorletFilterBasis(FilterBasis):
harmonic *= torch.where(m % 2 == 1, torch.sin(torch.ceil(m / 2) * math.pi * y / width), torch.cos(torch.ceil(m / 2) * math.pi * y / width))
# computes the envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
# vals = self.gaussian_window(r, width=width) * harmonic
vals = self.hann_window(r, width=width) * harmonic
return iidx, vals
......@@ -327,7 +324,7 @@ class ZernikeFilterBasis(FilterBasis):
# indexing logic for zernike polynomials
# the total index is given by (n * (n + 2) + l) // 2, which needs to be inverted
# precompute shifts in the level of the "pyramid"
nshifts = torch.arange(self.kernel_shape)
nshifts = torch.arange(self.kernel_shape, device=r.device)
nshifts = (nshifts + 1) * nshifts // 2
# find the level and position within the pyramid
nkernel = torch.searchsorted(nshifts, ikernel, right=True) - 1
......
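
A standalone sketch (with an illustrative kernel_shape, not code from the diff) of the Zernike index inversion described in the comment above: given the flat index k = (n * (n + 2) + l) // 2, the level n is recovered with searchsorted over the per-level shifts n * (n + 1) // 2, and the angular index l follows from the offset within that level.

import torch

kernel_shape = 4  # illustrative: pyramid levels n = 0 .. 3

# number of basis functions below level n: n * (n + 1) // 2
nshifts = torch.arange(kernel_shape)
nshifts = (nshifts + 1) * nshifts // 2

kernel_size = kernel_shape * (kernel_shape + 1) // 2
ikernel = torch.arange(kernel_size)

# invert k = (n * (n + 2) + l) // 2: first find the level, then the angular index
nkernel = torch.searchsorted(nshifts, ikernel, right=True) - 1
lkernel = 2 * (ikernel - nshifts[nkernel]) - nkernel

# round trip: rebuilding the flat index recovers 0, 1, ..., kernel_size - 1
assert torch.equal((nkernel * (nkernel + 2) + lkernel) // 2, ikernel)
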
......@@ -88,7 +88,7 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
on the interval [a, b]
"""
xlg = torch.from_numpy(np.linspace(a, b, n, endpoint=periodic))
xlg = torch.as_tensor(np.linspace(a, b, n, endpoint=periodic))
wlg = (b - a) / (n - periodic * 1) * torch.ones(n, requires_grad=False)
if not periodic:
......@@ -105,8 +105,8 @@ def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1
"""
xlg, wlg = np.polynomial.legendre.leggauss(n)
xlg = torch.from_numpy(xlg).clone()
wlg = torch.from_numpy(wlg).clone()
xlg = torch.as_tensor(xlg).clone()
wlg = torch.as_tensor(wlg).clone()
xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5
wlg = wlg * (b - a) * 0.5
......
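
A short usage sketch (with made-up values) for the Legendre-Gauss weights touched in the last hunk: leggauss nodes and weights of order n integrate polynomials up to degree 2n - 1 exactly on [-1, 1], and the affine rescaling to [a, b] mirrors the lines above.

import numpy as np
import torch

n = 5
xlg, wlg = np.polynomial.legendre.leggauss(n)
xlg = torch.as_tensor(xlg)
wlg = torch.as_tensor(wlg)

# integral of x^4 over [-1, 1] is 2/5
assert torch.isclose(torch.sum(wlg * xlg**4), torch.tensor(0.4, dtype=torch.float64))

# rescale nodes and weights to [a, b] as in legendre_gauss_weights
a, b = 0.0, 2.0
x = (b - a) * 0.5 * xlg + (b + a) * 0.5
w = wlg * (b - a) * 0.5

# integral of x^4 over [0, 2] is 32/5
assert torch.isclose(torch.sum(w * x**4), torch.tensor(32.0 / 5.0, dtype=torch.float64))
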