# edgewise_spline_weighting.py
import torch

if torch.cuda.is_available():
    from .edgewise_spline_weighting_gpu import EdgewiseSplineWeightingGPU


def edgewise_spline_weighting(input, weight, amount, index, k_fw, k_bw):
    """Apply edgewise spline weighting via the GPU implementation.

    Only CUDA tensors are supported. ``EdgewiseSplineWeightingGPU`` is
    imported at module level only when ``torch.cuda.is_available()``, and
    it is only referenced on the CUDA path below.

    Args:
        input: Input tensor; its ``is_cuda`` flag selects the backend.
        weight: Weight tensor with exactly three dimensions, unpacked as
            ``(K, M_in, M_out)`` and passed to the GPU function.
        amount: Forwarded unchanged to ``EdgewiseSplineWeightingGPU``.
        index: Forwarded unchanged to ``EdgewiseSplineWeightingGPU``.
        k_fw: Forwarded unchanged to ``EdgewiseSplineWeightingGPU``.
        k_bw: Forwarded unchanged to ``EdgewiseSplineWeightingGPU``.

    Returns:
        The result of calling the constructed
        ``EdgewiseSplineWeightingGPU(...)`` object on ``(input, weight)``.

    Raises:
        NotImplementedError: If ``input`` is not a CUDA tensor.
        ValueError: If ``weight`` does not have exactly three dimensions.
    """
    if not input.is_cuda:
        # No CPU implementation exists in this module.
        raise NotImplementedError(
            'edgewise_spline_weighting is only implemented for CUDA tensors')

    # Validate explicitly rather than relying on an opaque tuple-unpacking
    # error; the exception type (ValueError) is unchanged.
    if weight.dim() != 3:
        raise ValueError(
            'weight must have shape (K, M_in, M_out), got {}'.format(
                tuple(weight.size())))

    K, M_in, M_out = weight.size()
    return EdgewiseSplineWeightingGPU(amount, index, K, M_in, M_out, k_fw,
                                      k_bw)(input, weight)