# spline.py
# Last commit: "rename" (rusty1s)
import torch

from .spline_cpu import spline_cpu

if torch.cuda.is_available():
    from .spline_linear_gpu import spline_linear_gpu
    from .spline_quadratic_gpu import spline_quadratic_gpu
    from .spline_cubic_gpu import spline_cubic_gpu


def spline(input, kernel_size, is_open_spline, K, degree):
    """Dispatch a spline computation to the CPU or a GPU backend.

    If ``input.is_cuda`` is true, the call is forwarded to the GPU routine
    selected by ``degree`` (1 -> ``spline_linear_gpu``, 2 ->
    ``spline_quadratic_gpu``, 3 -> ``spline_cubic_gpu``). Otherwise it is
    forwarded to ``spline_cpu``.

    Args:
        input: Object exposing an ``is_cuda`` attribute; presumably a
            ``torch.Tensor`` -- confirm against callers.
        kernel_size: Forwarded unchanged to the selected backend.
        is_open_spline: Forwarded unchanged to the selected backend.
        K: Forwarded to the GPU backends only.
            NOTE(review): not passed to ``spline_cpu`` -- confirm this is
            intentional.
        degree: Selects the GPU backend (1, 2 or 3); on the CPU path it is
            forwarded to ``spline_cpu`` without being validated here.

    Returns:
        Whatever the selected backend returns.

    Raises:
        NotImplementedError: If ``input.is_cuda`` is true and ``degree`` is
            not 1, 2 or 3.
    """
    # CPU path first: the GPU backends are only imported when CUDA is
    # available, so they must not be referenced for non-CUDA inputs.
    if not input.is_cuda:
        return spline_cpu(input, kernel_size, is_open_spline, degree)

    if degree == 1:
        return spline_linear_gpu(input, kernel_size, is_open_spline, K)
    elif degree == 2:
        return spline_quadratic_gpu(input, kernel_size, is_open_spline, K)
    elif degree == 3:
        return spline_cubic_gpu(input, kernel_size, is_open_spline, K)

    raise NotImplementedError(
        'No GPU spline implementation for degree {!r} '
        '(supported degrees: 1, 2, 3)'.format(degree))