edgewise_spline_weighting.py 412 Bytes
Newer Older
rusty1s's avatar
rename  
rusty1s committed
1
2
3
4
5
6
7
8
import torch

if torch.cuda.is_available():
    from .edgewise_spline_weighting_gpu import EdgewiseSplineWeightingGPU


def edgewise_spline_weighting(input, weight, amount, index):
    """Apply edgewise spline weighting by dispatching to the CUDA implementation.

    Only CUDA tensors are supported. The work is delegated to
    ``EdgewiseSplineWeightingGPU``, which this module imports only when
    ``torch.cuda.is_available()`` is true.

    Args:
        input: Input tensor; must be a CUDA tensor (``input.is_cuda``).
        weight: Three-dimensional tensor whose size is unpacked as
            ``(K, M_in, M_out)``. Presumably K spline kernel weights,
            M_in input features and M_out output features -- TODO confirm
            against the GPU implementation.
        amount: Tensor with at least two dimensions; ``amount.size(1)`` is
            passed to the GPU function as ``k_max``.
        index: Tensor passed through unchanged to the GPU function
            (presumably per-edge kernel-weight indices -- verify).

    Returns:
        The result of ``EdgewiseSplineWeightingGPU(...)(input, weight)``.

    Raises:
        NotImplementedError: If ``input`` is not a CUDA tensor; no CPU
            implementation exists.
    """
    # Guard first: there is no CPU kernel, and the GPU class is only imported
    # when CUDA is available, so non-CUDA inputs must not reach the call below.
    if not input.is_cuda:
        raise NotImplementedError(
            'edgewise_spline_weighting is only implemented for CUDA tensors'
        )

    K, M_in, M_out = weight.size()
    k_max = amount.size(1)
    op = EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out, k_max)
    return op(input, weight)