ext.py 298 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
import torch_scatter.scatter_cpu
rusty1s's avatar
rusty1s committed
3
4

if torch.cuda.is_available():
rusty1s's avatar
rusty1s committed
5
    import torch_scatter.scatter_cuda
rusty1s's avatar
rusty1s committed
6
7
8


def get_func(name, tensor):
    """Return the compiled scatter kernel ``name`` matching ``tensor``'s device.

    Args:
        name: Attribute name of the kernel to fetch from the extension module.
        tensor: Tensor whose ``is_cuda`` flag selects the backend module.

    Returns:
        ``getattr(torch_scatter.scatter_cuda, name)`` when ``tensor.is_cuda``
        is true, otherwise ``getattr(torch_scatter.scatter_cpu, name)``.

    Raises:
        AttributeError: If the selected backend module has no attribute
            called ``name``.
    """
    # Only the chosen branch is evaluated. The ``scatter_cuda`` submodule is
    # only touched for CUDA tensors; at file level it is imported only when
    # CUDA is available.
    backend = (
        torch_scatter.scatter_cuda
        if tensor.is_cuda
        else torch_scatter.scatter_cpu
    )
    return getattr(backend, name)