Commit 652c4ab2 authored by Boris Bonev, committed by Boris Bonev

Adding FilterBasis datastructure to contain information about the type of basis used in disco convolution
parent 9dc07e9b
@@ -6,6 +6,8 @@
* Changing default grid in all SHT routines to `equiangular`
* Hotfix to the numpy version requirements
* Reworked DISCO filter basis datastructure
* Support for new filter basis types
### v0.7.2
...
@@ -43,7 +43,7 @@ from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics._filter_basis import compute_kernel_size
from torch_harmonics.filter_basis import get_filter_basis
# import custom C++/CUDA extensions if available
try:
@@ -56,90 +56,6 @@ except ImportError as err:
_cuda_extension_available = False
def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
kernel_size = (nr // 2) + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
# compute the support
if nr % 2 == 1:
ir = ikernel * dr
else:
ir = (ikernel + 0.5) * dr
# find the indices where the rotated position falls into the support of the kernel
iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
return iidx, vals
def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. Handles the special case
of an odd number of collocation points across the kernel diameter.
"""
kernel_size = (nr // 2) * nphi + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
dphi = 2.0 * math.pi / nphi
# disambiguate the even and odd cases and compute the support
if nr % 2 == 1:
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
else:
ir = (ikernel // nphi + 0.5) * dr
iphi = (ikernel % nphi) * dphi
# find the indices where the rotated position falls into the support of the kernel
if nr % 2 == 1:
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# find indices where conditions are met
iidx = torch.argwhere(cond_r & cond_phi)
# compute the distance to the collocation points
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = 1 - dist_r / dr
vals *= torch.where(
(iidx[:, 0] > 0),
(1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi),
1.0,
)
else:
# in the even case, the innermost basis functions overlap into the region with negative radius
rn = -r
phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
# find indices where conditions are met
iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr)
vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi)
valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr)
valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi)
vals += valsn
return iidx, vals
def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
"""
Discretely normalizes the convolution tensor.
@@ -189,8 +105,7 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
def _precompute_convolution_tensor_s2(
in_shape,
out_shape,
kernel_shape,
basis_type="piecewise linear",
filter_basis,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
@@ -216,14 +131,7 @@
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type=basis_type)
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
kernel_size = filter_basis.kernel_size
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
@@ -264,7 +172,7 @@
phi = torch.arctan2(y, x) + torch.pi
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = kernel_handle(theta, phi)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
@@ -303,13 +211,10 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
):
super().__init__()
if isinstance(kernel_shape, int):
self.kernel_shape = [kernel_shape]
else:
self.kernel_shape = kernel_shape
self.kernel_shape = kernel_shape
# get the total number of filters
self.kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type="piecewise linear")
# get the filter basis functions
self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type="piecewise linear")
# groups
self.groups = groups
@@ -328,6 +233,10 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
else:
self.bias = None
@property
def kernel_size(self):
return self.filter_basis.kernel_size
@abc.abstractmethod
def forward(self, x: torch.Tensor):
raise NotImplementedError
@@ -366,7 +275,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
raise ValueError("Error, theta_cutoff has to be positive.")
idx, vals = _precompute_convolution_tensor_s2(
in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
in_shape, out_shape, self.filter_basis, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
)
# sort the values
@@ -390,7 +299,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
def psi_idx(self):
@@ -460,7 +369,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# switch in_shape and out_shape since we want transpose conv
idx, vals = _precompute_convolution_tensor_s2(
out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
out_shape, in_shape, self.filter_basis, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
)
# sort the values
@@ -484,7 +393,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
def psi_idx(self):
...
@@ -44,10 +44,8 @@ from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics._filter_basis import compute_kernel_size
from torch_harmonics.filter_basis import get_filter_basis
from torch_harmonics.convolution import (
_compute_support_vals_isotropic,
_compute_support_vals_anisotropic,
_normalize_convolution_tensor_s2,
DiscreteContinuousConv,
)
@@ -73,8 +71,7 @@ except ImportError as err:
def _precompute_distributed_convolution_tensor_s2(
in_shape,
out_shape,
kernel_shape,
basis_type="piecewise linear",
filter_basis,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
@@ -100,14 +97,7 @@
assert len(in_shape) == 2
assert len(out_shape) == 2
kernel_size = compute_kernel_size(kernel_shape=kernel_shape, basis_type=basis_type)
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
kernel_size = filter_basis.kernel_size
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
@@ -148,7 +138,7 @@
phi = torch.arctan2(y, x) + torch.pi
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = kernel_handle(theta, phi)
iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)
@@ -243,8 +233,9 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
# set local shapes according to distributed mode:
self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
self.nlat_out_local = self.nlat_out
idx, vals = _precompute_distributed_convolution_tensor_s2(
in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
in_shape, out_shape, self.filter_basis, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
)
# sort the values
@@ -267,7 +258,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
def psi_idx(self):
@@ -376,7 +367,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# switch in_shape and out_shape since we want transpose conv
# distributed mode here is swapped because of the transpose
idx, vals = _precompute_distributed_convolution_tensor_s2(
out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
out_shape, in_shape, self.filter_basis, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
)
# sort the values
@@ -399,7 +390,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
@property
def psi_idx(self):
...
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import abc
from typing import List, Tuple, Union, Optional
import math
import torch
def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis_type: str):
"""Factory function to generate the appropriate filter basis"""
if basis_type == "piecewise linear":
return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
else:
raise ValueError(f"Unknown basis_type {basis_type}")
class AbstractFilterBasis(metaclass=abc.ABCMeta):
"""
Abstract base class for a filter basis
"""
def __init__(
self,
kernel_shape: Union[int, List[int], Tuple[int, int]],
):
self.kernel_shape = kernel_shape
@property
@abc.abstractmethod
def kernel_size(self):
raise NotImplementedError
@abc.abstractmethod
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the kernel's support and returns both indices and values. This routine is designed for sparse evaluation of the filter basis.
"""
raise NotImplementedError
class PiecewiseLinearFilterBasis(AbstractFilterBasis):
"""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
def __init__(
self,
kernel_shape: Union[int, List[int], Tuple[int, int]],
):
if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
kernel_shape = [kernel_shape[0], 1]
elif len(kernel_shape) != 2:
raise ValueError(f"expected kernel_shape to be a list or tuple of length 1 or 2 buu got {kernel_shape} instead.")
super().__init__(kernel_shape=kernel_shape)
@property
def kernel_size(self):
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
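A quick sanity check of this formula (an illustrative aside, not part of the file): for odd `kernel_shape[0]` there is one isotropic basis function at the kernel center plus `kernel_shape[0] // 2` rings of `kernel_shape[1]` azimuthal functions; for even `kernel_shape[0]` the center function is absent.

```python
# kernel_shape = [3, 4]: 1 center function + 1 ring of 4 -> 5 basis functions
assert (3 // 2) * 4 + 3 % 2 == 5
# kernel_shape = [4, 4]: no center function, 2 rings of 4 -> 8 basis functions
assert (4 // 2) * 4 + 4 % 2 == 8
# kernel_shape = [3, 1] (isotropic): 2 hat functions along the radius
assert (3 // 2) * 1 + 3 % 2 == 2
```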
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
# collocation points
nr = self.kernel_shape[0]
dr = 2 * r_cutoff / (nr + 1)
# compute the support
if nr % 2 == 1:
ir = ikernel * dr
else:
ir = (ikernel + 0.5) * dr
# find the indices where the rotated position falls into the support of the kernel
iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
return iidx, vals
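To make the hat functions concrete, a hedged numeric sketch (assuming `kernel_shape=[3]`, i.e. two radial hat functions with collocation points at radii `0` and `dr`): a query point halfway between the two collocation radii falls in the support of both and receives the value 0.5 from each.

```python
import math
import torch
from torch_harmonics.filter_basis import PiecewiseLinearFilterBasis

r_cutoff = 0.01 * math.pi
dr = 2 * r_cutoff / (3 + 1)  # radial spacing for nr = 3

basis = PiecewiseLinearFilterBasis(kernel_shape=[3])
r = torch.tensor([[0.5 * dr]])  # a single grid point between the collocation radii
phi = torch.zeros_like(r)

iidx, vals = basis.compute_support_vals(r, phi, r_cutoff=r_cutoff)
# expected: iidx = [[0, 0, 0], [1, 0, 0]], vals = [0.5, 0.5]
```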
def _compute_support_vals_anisotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. Handles the special case of an odd number of collocation points across the kernel diameter.
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
# collocation points
nr = self.kernel_shape[0]
nphi = self.kernel_shape[1]
dr = 2 * r_cutoff / (nr + 1)
dphi = 2.0 * math.pi / nphi
# disambiguate the even and odd cases and compute the support
if nr % 2 == 1:
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
else:
ir = (ikernel // nphi + 0.5) * dr
iphi = (ikernel % nphi) * dphi
# find the indices where the rotated position falls into the support of the kernel
if nr % 2 == 1:
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# find indices where conditions are met
iidx = torch.argwhere(cond_r & cond_phi)
# compute the distance to the collocation points
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = 1 - dist_r / dr
vals *= torch.where(
(iidx[:, 0] > 0),
(1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi),
1.0,
)
else:
# in the even case, the innermost basis functions overlap into the region with negative radius
rn = -r
phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
# find indices where conditions are met
iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr)
vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi)
valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr)
valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi)
vals += valsn
return iidx, vals
def compute_support_vals(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
if self.kernel_shape[1] > 1:
return self._compute_support_vals_anisotropic(r, phi, r_cutoff=r_cutoff)
else:
return self._compute_support_vals_isotropic(r, phi, r_cutoff=r_cutoff)
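Finally, a hedged end-to-end sketch of how `_precompute_convolution_tensor_s2` consumes the basis (the grid values here are illustrative rather than actual rotated coordinates). `r` and `phi` are 2D tensors of polar coordinates; each row of the returned `iidx` holds (basis function index, latitude index, longitude index), and `vals` holds the matching sparse basis values used to assemble the convolution tensor.

```python
import math
import torch
from torch_harmonics.filter_basis import get_filter_basis

theta_cutoff = 0.01 * math.pi
basis = get_filter_basis(kernel_shape=[3, 4], basis_type="piecewise linear")

# toy polar grid of shape (nlat_in, nlon_in) = (2, 4)
r = torch.linspace(0.0, 2.0 * theta_cutoff, 8).reshape(2, 4)
phi = torch.linspace(0.0, 2.0 * math.pi, 8).reshape(2, 4)

iidx, vals = basis.compute_support_vals(r, phi, r_cutoff=theta_cutoff)
# points with r > theta_cutoff lie outside every basis support and are dropped
```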