import torch
import torch_scatter.scatter_cpu
if torch.cuda.is_available():
import torch_scatter.scatter_cuda
def get_func(name, tensor):
if tensor.is_cuda: module = torch_scatter.scatter_cuda else: module = torch_scatter.scatter_cpu
return getattr(module, name)