edgewise_spline_weighting.py 319 Bytes
Newer Older
rusty1s's avatar
rename  
rusty1s committed
1
2
3
4
5
6
7
8
9
10
import torch

# The GPU implementation is only imported when CUDA is available, so the
# name `EdgewiseSplineWeightingGPU` is left undefined on CPU-only setups.
if torch.cuda.is_available():
    from .edgewise_spline_weighting_gpu import EdgewiseSplineWeightingGPU


def edgewise_spline_weighting(input, weight, amount, index):
    """Apply edgewise spline weighting to ``input`` using ``weight``.

    Dispatches on the device of ``input``: when ``input.is_cuda`` is true the
    work is delegated to ``EdgewiseSplineWeightingGPU``, which is constructed
    from ``amount`` and ``index`` and then called with ``(input, weight)``.
    No CPU implementation exists in this module.

    Args:
        input: Object exposing ``is_cuda`` (presumably a ``torch.Tensor`` of
            per-edge features -- TODO confirm against callers).
        weight: Forwarded unchanged to the GPU implementation.
        amount: Forwarded to the ``EdgewiseSplineWeightingGPU`` constructor.
        index: Forwarded to the ``EdgewiseSplineWeightingGPU`` constructor.

    Returns:
        Whatever ``EdgewiseSplineWeightingGPU(amount, index)(input, weight)``
        returns.

    Raises:
        NotImplementedError: If ``input`` is not a CUDA tensor.
    """
    # Guard clause: reject non-CUDA inputs up front with an explanatory
    # message instead of a bare exception.
    if not input.is_cuda:
        raise NotImplementedError(
            'edgewise_spline_weighting is only implemented for CUDA tensors')

    return EdgewiseSplineWeightingGPU(amount, index)(input, weight)