sparse.py 210 Bytes
Newer Older
rusty1s's avatar
to csr  
rusty1s committed
1
2
3
4
5
6
7
import torch


def SparseTensor(index, value, size):
    """Build a sparse COO tensor from indices, values and a size.

    Args:
        index: Tensor holding the coordinates of the non-zero entries,
            passed unchanged to ``torch.sparse_coo_tensor``.
        value: Tensor holding the non-zero entries. The result takes its
            dtype and device from this tensor.
        size: Size of the resulting sparse tensor.

    Returns:
        The sparse tensor created by ``torch.sparse_coo_tensor``.
    """
    # Use the generic factory rather than looking up a legacy typed class
    # (e.g. ``torch.sparse.FloatTensor``) by parsing ``value.type()``: the
    # string-based lookup only distinguished CPU from CUDA and broke for
    # dtypes that have no matching legacy sparse class.
    return torch.sparse_coo_tensor(
        index, value, size, dtype=value.dtype, device=value.device)