Commit b91f517c authored and committed by Boris Bonev

factoring out kernel_size computation

parent 837335f8
@@ -43,11 +43,13 @@ from functools import partial
 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
 from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
 from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
+from torch_harmonics._filter_basis import compute_kernel_size

 # import custom C++/CUDA extensions if available
 try:
     from disco_helpers import preprocess_psi
     import disco_cuda_extension

     _cuda_extension_available = True
 except ImportError as err:
     disco_cuda_extension = None
@@ -138,7 +140,7 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in
     return iidx, vals

-def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
+def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
     """
     Discretely normalizes the convolution tensor.
     """
@@ -146,11 +148,6 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
     nlat_in, nlon_in = in_shape
     nlat_out, nlon_out = out_shape

-    if len(kernel_shape) == 1:
-        kernel_size = math.ceil(kernel_shape[0] / 2)
-    elif len(kernel_shape) == 2:
-        kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2

     # reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
     idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)
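The inline branches removed in this hunk are what the new `compute_kernel_size` helper now encapsulates. Its actual implementation lives in `torch_harmonics._filter_basis` and is not part of this diff, so the following is only a minimal sketch of the equivalent logic for the "piecewise linear" basis, reconstructed from the removed lines; the name `compute_kernel_size_sketch` is a stand-in, not the library function.

```python
# Hedged sketch, mirroring the removed inline formulas; illustration only.
import math

def compute_kernel_size_sketch(kernel_shape, basis_type="piecewise linear"):
    """Total number of filter basis functions for a given kernel_shape."""
    if len(kernel_shape) == 1:
        # isotropic kernel: radial collocation points only
        return math.ceil(kernel_shape[0] / 2)
    elif len(kernel_shape) == 2:
        # anisotropic kernel: radial x azimuthal collocation points,
        # with the centre point counted once
        return (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
    raise ValueError("kernel_shape should be either one- or two-dimensional.")
```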
@@ -190,7 +187,15 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
 def _precompute_convolution_tensor_s2(
-    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
+    in_shape,
+    out_shape,
+    kernel_shape,
+    basis_type="piecewise linear",
+    grid_in="equiangular",
+    grid_out="equiangular",
+    theta_cutoff=0.01 * math.pi,
+    transpose_normalization=False,
+    merge_quadrature=False,
 ):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -211,6 +216,8 @@ def _precompute_convolution_tensor_s2(
     assert len(in_shape) == 2
     assert len(out_shape) == 2

+    kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type=basis_type)

     if len(kernel_shape) == 1:
         kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
     elif len(kernel_shape) == 2:
@@ -275,7 +282,7 @@ def _precompute_convolution_tensor_s2(
     else:
         quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in

     out_vals = _normalize_convolution_tensor_s2(
-        out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
+        out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
     )

     return out_idx, out_vals
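With the signature change above, callers can select the filter basis when precomputing the convolution tensor. A hedged usage sketch follows: the parameter names come from the signature in this diff, while the concrete values and the import path (`torch_harmonics.convolution`, inferred from the imports shown later in the diff) are illustrative assumptions.

```python
# Hedged usage sketch of the updated precompute routine; values are examples.
import math
from torch_harmonics.convolution import _precompute_convolution_tensor_s2

idx, vals = _precompute_convolution_tensor_s2(
    in_shape=(64, 128),             # (nlat_in, nlon_in)
    out_shape=(32, 64),             # (nlat_out, nlon_out)
    kernel_shape=[3, 4],            # anisotropic kernel
    basis_type="piecewise linear",  # new keyword introduced by this commit
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    merge_quadrature=True,
)
```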
@@ -301,16 +308,8 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
         else:
             self.kernel_shape = kernel_shape

-        if len(self.kernel_shape) == 1:
-            self.kernel_size = math.ceil(self.kernel_shape[0] / 2)
-            if self.kernel_shape[0] % 2 == 0:
-                warn(
-                    "Detected isotropic kernel with even number of collocation points in the radial direction. This feature is only supported out of consistency and may lead to unexpected behavior."
-                )
-        elif len(self.kernel_shape) == 2:
-            self.kernel_size = (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
-        if len(self.kernel_shape) > 2:
-            raise ValueError("kernel_shape should be either one- or two-dimensional.")
+        # get the total number of filters
+        self.kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type="piecewise linear")

         # groups
         self.groups = groups
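If the refactor is purely mechanical, the helper should reproduce the removed inline formulas for the default "piecewise linear" basis. A small parity check, assuming only the import path and call signature shown in this diff, might look like the following; the expected values are taken from the removed code.

```python
# Hedged sanity check: compare compute_kernel_size against the removed formulas.
import math
from torch_harmonics._filter_basis import compute_kernel_size

for kernel_shape in ([3], [5], [3, 4], [5, 8]):
    if len(kernel_shape) == 1:
        expected = math.ceil(kernel_shape[0] / 2)
    else:
        expected = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
    got = compute_kernel_size(kernel_shape=kernel_shape, basis_type="piecewise linear")
    assert got == expected, (kernel_shape, got, expected)
```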
...
@@ -44,7 +44,7 @@ from functools import partial
 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
 from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
 from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
+from torch_harmonics._filter_basis import compute_kernel_size
 from torch_harmonics.convolution import (
     _compute_support_vals_isotropic,
     _compute_support_vals_anisotropic,
@@ -52,6 +52,7 @@ from torch_harmonics.convolution import (
     DiscreteContinuousConv,
 )

 from torch_harmonics.distributed import polar_group_size, azimuth_group_size
 from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
 from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region
@@ -62,6 +63,7 @@ from torch_harmonics.distributed import compute_split_shapes, split_tensor_along
 try:
     from disco_helpers import preprocess_psi
     import disco_cuda_extension

     _cuda_extension_available = True
 except ImportError as err:
     disco_cuda_extension = None
@@ -69,7 +71,15 @@ except ImportError as err:
 def _precompute_distributed_convolution_tensor_s2(
-    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
+    in_shape,
+    out_shape,
+    kernel_shape,
+    basis_type="piecewise linear",
+    grid_in="equiangular",
+    grid_out="equiangular",
+    theta_cutoff=0.01 * math.pi,
+    transpose_normalization=False,
+    merge_quadrature=False,
 ):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -90,6 +100,8 @@ def _precompute_distributed_convolution_tensor_s2(
     assert len(in_shape) == 2
     assert len(out_shape) == 2

+    kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type=basis_type)

     if len(kernel_shape) == 1:
         kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
     elif len(kernel_shape) == 2:
@@ -154,7 +166,9 @@ def _precompute_distributed_convolution_tensor_s2(
         quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
     else:
         quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in

-    out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature)
+    out_vals = _normalize_convolution_tensor_s2(
+        out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
+    )

     # TODO: this part can be split off into it's own function
     # split the latitude indices:
...@@ -163,7 +177,7 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -163,7 +177,7 @@ def _precompute_distributed_convolution_tensor_s2(
split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar) split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
offsets = [0] + list(accumulate(split_shapes)) offsets = [0] + list(accumulate(split_shapes))
start_idx = offsets[comm_rank_polar] start_idx = offsets[comm_rank_polar]
end_idx = offsets[comm_rank_polar+1] end_idx = offsets[comm_rank_polar + 1]
# once normalization is done we can throw away the entries which correspond to input latitudes we do not care about # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
lats = out_idx[2] // nlon_in lats = out_idx[2] // nlon_in
@@ -171,7 +185,7 @@ def _precompute_distributed_convolution_tensor_s2(
     ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze()
     out_vals = out_vals[ilats]
     # for the indices we need to recompute them to refer to local indices of the input tenor
-    out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats]-start_idx) * nlon_in + lons[ilats]], dim=0)
+    out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]], dim=0)

     return out_idx, out_vals
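The distributed variant keeps the normalization global and only afterwards drops entries belonging to other polar ranks, remapping the flattened input index `lat * nlon_in + lon` to a rank-local one. A hedged sketch of that splitting arithmetic follows; `split_shapes_sketch` is a stand-in for `compute_split_shapes` (assumed here to be an even split with the remainder spread over the first ranks), and only the offsets/start_idx/end_idx pattern is taken from the diff.

```python
# Hedged sketch of the per-rank latitude range computation shown above.
from itertools import accumulate

def split_shapes_sketch(n, num_chunks):
    # stand-in for compute_split_shapes: near-even split of n latitudes
    base, rem = divmod(n, num_chunks)
    return [base + (1 if c < rem else 0) for c in range(num_chunks)]

nlat_in, comm_size_polar, comm_rank_polar = 64, 4, 1  # illustrative values
split_shapes = split_shapes_sketch(nlat_in, comm_size_polar)
offsets = [0] + list(accumulate(split_shapes))
start_idx, end_idx = offsets[comm_rank_polar], offsets[comm_rank_polar + 1]
# rank 1 keeps global input latitudes [start_idx, end_idx) and rewrites the
# flattened index lat * nlon_in + lon to (lat - start_idx) * nlon_in + lon
print(start_idx, end_idx)  # 16 32 for this example
```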
...