weighting.py 298 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
7
@torch.jit.script
def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
                     basis: torch.Tensor,
                     weight_index: torch.Tensor) -> torch.Tensor:
    """Apply the ``torch_spline_conv.spline_weighting`` custom operator.

    Thin TorchScript wrapper: the four tensor arguments are forwarded
    unchanged to ``torch.ops.torch_spline_conv.spline_weighting`` and that
    operator's result is returned. No validation, reshaping or dtype
    conversion happens here; all computation lives inside the registered op.

    Because of ``@torch.jit.script``, this function is compiled when the
    module is imported. The ``torch_spline_conv`` operator library must
    therefore already be registered with ``torch.ops`` at that point.
    NOTE(review): presumably the package loads its compiled extension
    before importing this module — confirm.

    Args:
        x: Input feature tensor passed straight to the op.
            NOTE(review): shape/dtype not established here — presumably
            one row per edge/sample; confirm against the op.
        weight: Weight tensor passed straight to the op.
            NOTE(review): presumably the kernel weights indexed by
            ``weight_index`` — confirm.
        basis: Basis-value tensor passed straight to the op.
            NOTE(review): presumably paired element-wise with
            ``weight_index`` — confirm.
        weight_index: Index tensor passed straight to the op.
            NOTE(review): presumably integer indices into ``weight`` —
            confirm layout and dtype.

    Returns:
        The tensor produced by the custom operator.
    """
    return torch.ops.torch_spline_conv.spline_weighting(
        x, weight, basis, weight_index)