spline.py 445 Bytes
Newer Older
rusty1s's avatar
rename  
rusty1s committed
1
2
3
import torch

if torch.cuda.is_available():
4
    from .compute_spline_basis import compute_spline_basis
rusty1s's avatar
rename  
rusty1s committed
5
6
7
8
    from .spline_quadratic_gpu import spline_quadratic_gpu
    from .spline_cubic_gpu import spline_cubic_gpu


9
def spline(input, kernel_size, is_open_spline, K, degree, basis_kernel):
    """Compute the spline basis for ``input``, dispatching on its device.

    Only CUDA inputs are supported. For them the call is forwarded to
    ``compute_spline_basis``, which the module imports only when
    ``torch.cuda.is_available()`` is true. Any other input raises
    ``NotImplementedError``.

    Args:
        input: Tensor-like object whose ``is_cuda`` flag selects the path.
        kernel_size: Forwarded unchanged to ``compute_spline_basis``.
        is_open_spline: Forwarded unchanged to ``compute_spline_basis``.
        K: Forwarded unchanged to ``compute_spline_basis``.
        degree: Accepted for interface compatibility but NOT forwarded.
            NOTE(review): presumably the spline degree is already encoded
            in ``basis_kernel`` -- confirm against callers.
        basis_kernel: Forwarded unchanged to ``compute_spline_basis``.

    Returns:
        Whatever ``compute_spline_basis`` returns for the CUDA path.

    Raises:
        NotImplementedError: If ``input.is_cuda`` is false.
    """
    if not input.is_cuda:
        # Only a GPU implementation exists. Fail with an actionable message
        # instead of a bare NotImplementedError.
        raise NotImplementedError(
            "spline() currently only supports CUDA inputs; "
            "got an input with is_cuda=False"
        )
    return compute_spline_basis(input, kernel_size, is_open_spline, K, basis_kernel)