Commit 0c067c86 authored by Boris Bonev, committed by Boris Bonev

intermediate release with reworked normalization of S2 convolutions

parent 34927a33
......@@ -8,8 +8,10 @@
* Hotfix to the numpy version requirements
* Changing default grid in all SHT routines to `equiangular`, which makes it consistent with DISCO convolutions
* Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
* New filter basis normalization in DISCO convolutions
* Reworked DISCO filter basis data structure
* Support for new filter basis types
* Adding Morlet-like basis functions on a spherical disk
### v0.7.2
......
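The headline change is the new `basis_norm_mode` argument on the DISCO convolution modules (added in the diff below, with the values `"sum"` and `"individual"`). Below is a minimal usage sketch; the import path, the leading `in_channels`/`out_channels` arguments, and the `(batch, channel, nlat, nlon)` input layout are assumed from the existing API and are not shown in this diff.

```python
import torch
from torch_harmonics import DiscreteContinuousConvS2  # assumed import path

# hypothetical example: the leading in_channels/out_channels/in_shape arguments
# are assumed from the existing constructor, which this diff only shows in part
conv = DiscreteContinuousConvS2(
    in_channels=3,
    out_channels=8,
    in_shape=(32, 64),               # (nlat, nlon) of the input grid
    out_shape=(32, 64),              # (nlat, nlon) of the output grid
    kernel_shape=[3, 4],
    basis_type="piecewise linear",   # default basis, unchanged in this commit
    basis_norm_mode="sum",           # new argument; "individual" is the alternative
    grid_in="equiangular",
    grid_out="equiangular",
)

x = torch.randn(1, 3, 32, 64)        # assumed (batch, channel, nlat, nlon) layout
y = conv(x)                          # expected shape: (1, 8, 32, 64)
```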
......@@ -56,9 +56,11 @@ except ImportError as err:
_cuda_extension_available = False
def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
def _normalize_convolution_tensor_s2(
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="sum", merge_quadrature=False, eps=1e-9
):
"""
Discretely normalizes the convolution tensor.
Discretely normalizes the convolution tensor. Supports different normalization modes.
"""
nlat_in, nlon_in = in_shape
......@@ -74,10 +76,21 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
# loop through dimensions which require normalization
for ik in range(kernel_size):
for ilat in range(nlat_in):
# get relevant entries
iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm = torch.sum(psi_vals[iidx] * q[iidx])
# get relevant entries depending on the normalization mode
if basis_norm_mode == "individual":
iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
elif basis_norm_mode == "sum":
# this repeats the normalization across the kernel dimension, but that shouldn't lead to issues
iidx = torch.argwhere(idx[2] == ilat)
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
if merge_quadrature:
# the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
psi_vals[iidx] = psi_vals[iidx] * q[iidx] * nlon_in / nlon_out / (vnorm + eps)
......@@ -90,10 +103,20 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
# loop through dimensions which require normalization
for ik in range(kernel_size):
for ilat in range(nlat_out):
# get relevant entries
iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
# normalize
vnorm = torch.sum(psi_vals[iidx] * q[iidx])
# get relevant entries depending on the normalization mode
if basis_norm_mode == "individual":
iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
# normalize
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
elif basis_norm_mode == "sum":
# this repeats the normalization across the kernel dimension, but that shouldn't lead to issues
iidx = torch.argwhere(idx[1] == ilat)
# normalize
vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
else:
raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
if merge_quadrature:
psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (vnorm + eps)
else:
......@@ -110,6 +133,7 @@ def _precompute_convolution_tensor_s2(
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
basis_norm_mode="sum",
merge_quadrature=False,
):
"""
......@@ -136,6 +160,7 @@ def _precompute_convolution_tensor_s2(
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
......@@ -145,6 +170,12 @@ def _precompute_convolution_tensor_s2(
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
# compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_idx = []
out_vals = []
for t in range(nlat_out):
......@@ -185,12 +216,16 @@ def _precompute_convolution_tensor_s2(
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_vals = _normalize_convolution_tensor_s2(
out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
out_idx,
out_vals,
in_shape,
out_shape,
kernel_size,
quad_weights,
transpose_normalization=transpose_normalization,
basis_norm_mode=basis_norm_mode,
merge_quadrature=merge_quadrature,
)
return out_idx, out_vals
......@@ -198,7 +233,7 @@ def _precompute_convolution_tensor_s2(
class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
"""
Abstract base class for DISCO convolutions
Abstract base class for discrete-continuous convolutions
"""
def __init__(
......@@ -245,7 +280,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
class DiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
......@@ -258,6 +293,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -277,7 +313,15 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
raise ValueError("Error, theta_cutoff has to be positive.")
idx, vals = _precompute_convolution_tensor_s2(
in_shape, out_shape, self.filter_basis, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
in_shape,
out_shape,
self.filter_basis,
grid_in=grid_in,
grid_out=grid_out,
theta_cutoff=theta_cutoff,
transpose_normalization=False,
basis_norm_mode=basis_norm_mode,
merge_quadrature=True,
)
# sort the values
......@@ -339,7 +383,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
......@@ -352,6 +396,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -372,7 +417,15 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# switch in_shape and out_shape since we want transpose conv
idx, vals = _precompute_convolution_tensor_s2(
out_shape, in_shape, self.filter_basis, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
out_shape,
in_shape,
self.filter_basis,
grid_in=grid_out,
grid_out=grid_in,
theta_cutoff=theta_cutoff,
transpose_normalization=True,
basis_norm_mode=basis_norm_mode,
merge_quadrature=True,
)
# sort the values
......
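The difference between the two normalization modes in `_normalize_convolution_tensor_s2` comes down to which entries are summed into `vnorm`: per kernel basis function (`"individual"`) or across all basis functions at a given latitude (`"sum"`). Here is a simplified, self-contained sketch of that logic with a hypothetical helper name, ignoring the `merge_quadrature` and `transpose_normalization` branches.

```python
import torch

def normalize_rows(idx: torch.Tensor, vals: torch.Tensor, q: torch.Tensor,
                   kernel_size: int, nlat_out: int, mode: str = "sum",
                   eps: float = 1e-9) -> torch.Tensor:
    """Simplified sketch of the Psi normalization.

    idx  : (3, nnz) indices of the sparse tensor (kernel, lat_out, flattened input point)
    vals : (nnz,) values of the sparse tensor
    q    : (nnz,) quadrature weight attached to each entry
    """
    vals = vals.clone()
    for ik in range(kernel_size):
        for ilat in range(nlat_out):
            if mode == "individual":
                # each kernel basis function is normalized on its own
                iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat)).squeeze(-1)
            elif mode == "sum":
                # all basis functions contributing to this latitude share one norm;
                # the same rows are revisited for every ik, but after the first pass
                # the weighted norm is already ~1, so the repetition is (nearly) idempotent
                iidx = torch.argwhere(idx[1] == ilat).squeeze(-1)
            else:
                raise ValueError(f"Unknown basis normalization mode {mode}.")
            vnorm = torch.sum(vals[iidx].abs() * q[iidx])
            vals[iidx] = vals[iidx] / (vnorm + eps)
    return vals
```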
......@@ -76,6 +76,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
basis_norm_mode="sum",
merge_quadrature=False,
):
"""
......@@ -111,6 +112,12 @@ def _precompute_distributed_convolution_tensor_s2(
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
# compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_idx = []
out_vals = []
for t in range(nlat_out):
......@@ -151,13 +158,16 @@ def _precompute_distributed_convolution_tensor_s2(
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
# perform the normalization over the entire psi matrix
if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_vals = _normalize_convolution_tensor_s2(
out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
out_idx,
out_vals,
in_shape,
out_shape,
kernel_size,
quad_weights,
transpose_normalization=transpose_normalization,
basis_norm_mode=basis_norm_mode,
merge_quadrature=merge_quadrature,
)
# TODO: this part can be split off into its own function
......@@ -197,6 +207,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -236,7 +247,15 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
self.nlat_out_local = self.nlat_out
idx, vals = _precompute_distributed_convolution_tensor_s2(
in_shape, out_shape, self.filter_basis, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
in_shape,
out_shape,
self.filter_basis,
grid_in=grid_in,
grid_out=grid_out,
theta_cutoff=theta_cutoff,
transpose_normalization=False,
basis_norm_mode=basis_norm_mode,
merge_quadrature=True,
)
# sort the values
......@@ -328,6 +347,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "sum",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -369,7 +389,15 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# switch in_shape and out_shape since we want transpose conv
# distributed mode here is swapped because of the transpose
idx, vals = _precompute_distributed_convolution_tensor_s2(
out_shape, in_shape, self.filter_basis, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
out_shape,
in_shape,
self.filter_basis,
grid_in=grid_out,
grid_out=grid_in,
theta_cutoff=theta_cutoff,
transpose_normalization=True,
basis_norm_mode=basis_norm_mode,
merge_quadrature=True,
)
# sort the values
......
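Both precompute routines now build the quadrature weights up front as 2π·w/nlon_in, i.e. one surface-area element per grid point that is later merged into Psi. A self-contained sanity-check sketch follows, using crude trapezoidal latitude weights instead of the library's `_precompute_latitudes`: summed over all grid points, these weights should recover the area of the unit sphere, 4π.

```python
import math
import torch

nlat, nlon = 64, 128
theta = torch.linspace(0, math.pi, nlat)                  # equiangular colatitudes
w = torch.sin(theta) * (math.pi / (nlat - 1))             # crude trapezoidal latitude weights (illustration only)
quad_weights = 2.0 * torch.pi * w.reshape(-1, 1) / nlon   # same formula as in the diff

# each latitude ring has nlon points, so the total weight approximates 4*pi
print(float((quad_weights * nlon).sum()), 4 * math.pi)    # ≈ 12.56 vs 12.566...
```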
......@@ -69,7 +69,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=4*math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
theta_cutoff=math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
)
def forward(self, x):
......@@ -115,7 +115,7 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=4*math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
theta_cutoff=math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
)
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
......
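The encoder/decoder example tightens `theta_cutoff` by a factor of 4, shrinking the filter support accordingly. A quick back-of-the-envelope check for a hypothetical grid with 91 latitude rings:

```python
import math

nlat_out = 91  # hypothetical grid size for illustration
old_cutoff = 4 * math.sqrt(2) * math.pi / (nlat_out - 1)   # ≈ 0.197 rad ≈ 11.3 deg
new_cutoff = math.sqrt(2) * math.pi / (nlat_out - 1)       # ≈ 0.049 rad ≈ 2.8 deg
print(math.degrees(old_cutoff), math.degrees(new_cutoff))
```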
......@@ -47,7 +47,7 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
raise ValueError(f"Unknown basis_type {basis_type}")
class AbstractFilterBasis(metaclass=abc.ABCMeta):
class FilterBasis(metaclass=abc.ABCMeta):
"""
Abstract base class for a filter basis
"""
......@@ -72,7 +72,7 @@ class AbstractFilterBasis(metaclass=abc.ABCMeta):
raise NotImplementedError
class PiecewiseLinearFilterBasis(AbstractFilterBasis):
class PiecewiseLinearFilterBasis(FilterBasis):
"""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
......@@ -190,7 +190,7 @@ class PiecewiseLinearFilterBasis(AbstractFilterBasis):
else:
return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff)
class DiskMorletFilterBasis(AbstractFilterBasis):
class DiskMorletFilterBasis(FilterBasis):
"""
Morlet-like Filter basis. A Gaussian is multiplied with a Fourier basis in x and y.
"""
......@@ -228,7 +228,8 @@ class DiskMorletFilterBasis(AbstractFilterBasis):
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
# # computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
width = 0.01
# width = 0.01
width = 0.25
# width = 1.0
# envelope = self._gaussian_envelope(r, width=0.25 * r_cutoff)
......@@ -245,7 +246,7 @@ class DiskMorletFilterBasis(AbstractFilterBasis):
# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals = self._gaussian_envelope(r[iidx[:, 1], iidx[:, 2]] / r_cutoff, width=width) * harmonic[iidx[:, 0], iidx[:, 1], iidx[:, 2]] / disk_area
vals = torch.ones_like(vals)
# vals = torch.ones_like(vals)
return iidx, vals
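For orientation, here is a standalone sketch of the Morlet-like construction that `DiskMorletFilterBasis` implements: a Gaussian envelope (with the width set to 0.25 in this commit) multiplied by a planar Fourier harmonic in x and y. The frequency scaling, the `disk_area` normalization, and the exact envelope used by the library are simplified away here.

```python
import torch

def gaussian_envelope(r: torch.Tensor, width: float = 0.25) -> torch.Tensor:
    # simple Gaussian bump; the library's _gaussian_envelope may differ in detail
    return torch.exp(-0.5 * (r / width) ** 2)

def morlet_like_value(r: torch.Tensor, phi: torch.Tensor, kx: int, ky: int,
                      r_cutoff: float, width: float = 0.25) -> torch.Tensor:
    """Sketch: Gaussian envelope times a Fourier harmonic in (x, y) = r*(cos phi, sin phi),
    restricted to the disk r <= r_cutoff."""
    x = r * torch.cos(phi)
    y = r * torch.sin(phi)
    harmonic = torch.cos(2.0 * torch.pi * (kx * x + ky * y) / r_cutoff)
    vals = gaussian_envelope(r / r_cutoff, width=width) * harmonic
    return torch.where(r <= r_cutoff, vals, torch.zeros_like(vals))
```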