# edgewise_spline_weighting.py (428 bytes)
# Last change: "rename" (committed by rusty1s).

import torch

from .edgewise_spline_weighting_cpu import EdgewiseSplineWeightingCPU

if torch.cuda.is_available():
    from .edgewise_spline_weighting_gpu import EdgewiseSplineWeightingGPU


def edgewise_spline_weighting(input, weight, amount, index):
    """Run edgewise spline weighting on the backend matching ``input``.

    The implementation class is chosen from ``input.is_cuda``:
    ``EdgewiseSplineWeightingGPU`` for CUDA tensors, otherwise
    ``EdgewiseSplineWeightingCPU``. ``amount`` and ``index`` go to the
    chosen class's constructor. The resulting op object is then called
    with ``(input, weight)``, and its result is returned unchanged.

    Args:
        input: Tensor whose device decides the backend.
        weight: Second argument forwarded to the op call.
        amount: Forwarded to the op constructor. Presumably the per-edge
            spline basis values; confirm against the CPU/GPU implementations.
        index: Forwarded to the op constructor. Presumably the per-edge
            weight indices; confirm against the CPU/GPU implementations.

    Returns:
        Whatever the selected op returns when called on ``(input, weight)``.
    """
    # Conditional expression: the GPU name is only looked up for CUDA input.
    # It is only imported when CUDA is available.
    backend = (EdgewiseSplineWeightingGPU
               if input.is_cuda else EdgewiseSplineWeightingCPU)
    op = backend(amount, index)
    return op(input, weight)