weighting.py 333 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
7
8
9
@torch.jit.script
def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
                     basis: torch.Tensor,
                     weight_index: torch.Tensor) -> torch.Tensor:
    """Invoke the ``spline_conv::spline_weighting`` custom operator.

    This is a thin TorchScript-compiled wrapper: the four tensors are
    forwarded unchanged, in the order ``x``, ``weight``, ``basis``,
    ``weight_index``, and the operator's result is returned as-is.

    NOTE(review): the operator is resolved through
    ``torch.ops.spline_conv``, so it presumably has to be registered
    (extension library loaded) before this definition is scripted --
    confirm against the package's initialisation code. Expected tensor
    shapes and dtypes are defined by the operator itself and are not
    visible from this wrapper.
    """
    out = torch.ops.spline_conv.spline_weighting(
        x, weight, basis, weight_index)
    return out