Commit b91f517c authored by Boris Bonev, committed by Boris Bonev

factoring out kernel_size computation

parent 837335f8
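
This commit replaces the kernel-size arithmetic that was previously duplicated in `_normalize_convolution_tensor_s2`, `_precompute_convolution_tensor_s2`, and `DiscreteContinuousConv` with a single helper, `compute_kernel_size`, imported from `torch_harmonics._filter_basis`. The helper itself is not part of this diff; the sketch below reconstructs what it presumably computes for the "piecewise linear" basis from the inline code removed here, so the exact signature and the handling of other `basis_type` values are assumptions.

```python
import math
from typing import Sequence

def compute_kernel_size(kernel_shape: Sequence[int], basis_type: str = "piecewise linear") -> int:
    # Hypothetical reconstruction based on the inline code removed by this commit;
    # only the "piecewise linear" basis is covered here.
    if basis_type != "piecewise linear":
        raise NotImplementedError(f"basis_type {basis_type!r} not covered by this sketch")
    if len(kernel_shape) == 1:
        # isotropic kernel: collocation points in the radial direction only
        return math.ceil(kernel_shape[0] / 2)
    elif len(kernel_shape) == 2:
        # anisotropic kernel: radial times azimuthal collocation points
        return (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")
```

With the helper factored out, the precompute functions and the `DiscreteContinuousConv` base class all obtain `kernel_size` from the same place instead of repeating the formula.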
@@ -43,11 +43,13 @@ from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics._filter_basis import compute_kernel_size
# import custom C++/CUDA extensions if available
try:
from disco_helpers import preprocess_psi
import disco_cuda_extension
_cuda_extension_available = True
except ImportError as err:
disco_cuda_extension = None
@@ -138,7 +140,7 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in
return iidx, vals
def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
"""
@@ -146,11 +148,6 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
if len(kernel_shape) == 1:
kernel_size = math.ceil(kernel_shape[0] / 2)
elif len(kernel_shape) == 2:
kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
# reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)
@@ -190,7 +187,15 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
def _precompute_convolution_tensor_s2(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
in_shape,
out_shape,
kernel_shape,
basis_type="piecewise linear",
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
merge_quadrature=False,
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -211,6 +216,8 @@ def _precompute_convolution_tensor_s2(
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type=basis_type)
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
elif len(kernel_shape) == 2:
@@ -275,7 +282,7 @@ def _precompute_convolution_tensor_s2(
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_vals = _normalize_convolution_tensor_s2(
out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
)
return out_idx, out_vals
@@ -301,16 +308,8 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
else:
self.kernel_shape = kernel_shape
if len(self.kernel_shape) == 1:
self.kernel_size = math.ceil(self.kernel_shape[0] / 2)
if self.kernel_shape[0] % 2 == 0:
warn(
"Detected isotropic kernel with even number of collocation points in the radial direction. This feature is only supported out of consistency and may lead to unexpected behavior."
)
elif len(self.kernel_shape) == 2:
self.kernel_size = (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
if len(self.kernel_shape) > 2:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
# get the total number of filters
self.kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type="piecewise linear")
# groups
self.groups = groups
@@ -44,7 +44,7 @@ from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics._filter_basis import compute_kernel_size
from torch_harmonics.convolution import (
_compute_support_vals_isotropic,
_compute_support_vals_anisotropic,
@@ -52,6 +52,7 @@ from torch_harmonics.convolution import (
DiscreteContinuousConv,
)
from torch_harmonics.distributed import polar_group_size, azimuth_group_size
from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region
@@ -62,6 +63,7 @@ from torch_harmonics.distributed import compute_split_shapes, split_tensor_along
try:
from disco_helpers import preprocess_psi
import disco_cuda_extension
_cuda_extension_available = True
except ImportError as err:
disco_cuda_extension = None
@@ -69,7 +71,15 @@ except ImportError as err:
def _precompute_distributed_convolution_tensor_s2(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
in_shape,
out_shape,
kernel_shape,
basis_type="piecewise linear",
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
merge_quadrature=False,
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -90,6 +100,8 @@ def _precompute_distributed_convolution_tensor_s2(
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type=basis_type)
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
elif len(kernel_shape) == 2:
@@ -154,7 +166,9 @@ def _precompute_distributed_convolution_tensor_s2(
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature)
out_vals = _normalize_convolution_tensor_s2(
out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
)
# TODO: this part can be split off into its own function
# split the latitude indices:
@@ -163,7 +177,7 @@ def _precompute_distributed_convolution_tensor_s2(
split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
offsets = [0] + list(accumulate(split_shapes))
start_idx = offsets[comm_rank_polar]
end_idx = offsets[comm_rank_polar+1]
end_idx = offsets[comm_rank_polar + 1]
# once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
lats = out_idx[2] // nlon_in
@@ -171,7 +185,7 @@ def _precompute_distributed_convolution_tensor_s2(
ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze()
out_vals = out_vals[ilats]
# for the indices we need to recompute them to refer to local indices of the input tensor
out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats]-start_idx) * nlon_in + lons[ilats]], dim=0)
out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)
return out_idx, out_vals
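
For reference, a minimal call to the refactored `_precompute_convolution_tensor_s2` with the new `basis_type` keyword might look like the sketch below. The resolutions, kernel shape, and cutoff are placeholder values, and since this is an internal helper, the exact behavior should be checked against the installed torch_harmonics version.

```python
import math
from torch_harmonics.convolution import _precompute_convolution_tensor_s2

# Placeholder grid resolutions and an anisotropic kernel shape;
# kernel_size is now derived internally via compute_kernel_size(...).
idx, vals = _precompute_convolution_tensor_s2(
    in_shape=(32, 64),
    out_shape=(16, 32),
    kernel_shape=[3, 4],            # -> kernel_size = (3 // 2) * 4 + 3 % 2 = 5
    basis_type="piecewise linear",  # new keyword introduced by this commit
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.1 * math.pi,
    transpose_normalization=False,
    merge_quadrature=True,
)
```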