"""Torch tensor type names and a constructor that builds a tensor by type name."""
import torch

try:
    # Early PyTorch releases expose this list of ``*Base`` class names here.
    from torch._tensor_docs import tensor_classes
except ImportError:
    # NOTE(review): newer PyTorch no longer defines ``tensor_classes`` in this
    # private module; fall back to the same names so this module still imports.
    tensor_classes = [
        'DoubleTensorBase',
        'FloatTensorBase',
        'LongTensorBase',
        'IntTensorBase',
        'ShortTensorBase',
        'CharTensorBase',
        'ByteTensorBase',
    ]

# Drop the trailing 'Base' (4 characters) from each class name,
# e.g. 'FloatTensorBase' -> 'FloatTensor'.
tensors = [t[:-4] for t in tensor_classes]
tensors.remove('ShortTensor')  # TODO: PyTorch `atomicAdd` bug with short type.


def Tensor(str, x):  # noqa: A002 -- name kept so existing callers keep working
    """Build a tensor of the torch type named *str* from data *x*.

    Args:
        str: Name of a constructor on the ``torch`` module, such as one of the
            entries in ``tensors`` (e.g. ``'FloatTensor'``).
        x: Data passed unchanged to that constructor.

    Returns:
        The result of ``getattr(torch, str)(x)``.

    Raises:
        AttributeError: If ``torch`` has no attribute named *str*.
    """
    tensor_type = getattr(torch, str)
    return tensor_type(x)