# ext.py (227 bytes), last committed by rusty1s.
import torch
import spline_conv_cpu

if torch.cuda.is_available():
    import spline_conv_cuda


def get_func(name, tensor):
    """Return the compiled function ``name`` matching the device of ``tensor``.

    The ``spline_conv_cuda`` extension module is consulted when
    ``tensor.is_cuda`` is true; otherwise ``spline_conv_cpu`` is used.

    Args:
        name: Name of the attribute to fetch from the chosen extension module.
        tensor: Object whose ``is_cuda`` flag picks the backend (presumably a
            ``torch.Tensor``).

    Returns:
        The attribute ``name`` of the selected extension module.

    Raises:
        AttributeError: If the selected module has no attribute ``name``.
        NameError: If ``tensor.is_cuda`` is true but ``spline_conv_cuda`` was
            never imported. That module is only imported when
            ``torch.cuda.is_available()`` is true at module load.
    """
    # Only touch the CUDA module name on the CUDA branch, since it may not be
    # bound on machines without CUDA.
    if tensor.is_cuda:
        backend = spline_conv_cuda
    else:
        backend = spline_conv_cpu
    return getattr(backend, name)