# utils.py
import torch as th


def sparse_sp2th(matrix, *, dtype=th.float32):
    """Convert a SciPy sparse matrix into a PyTorch sparse COO tensor.

    The input is first converted to COO format with ``matrix.tocoo()``. Its
    row/column index arrays become the ``(2, nnz)`` int64 index tensor, and
    its stored values become the value tensor.

    Args:
        matrix: An object exposing ``tocoo()`` (e.g. a ``scipy.sparse``
            matrix or array). The result must provide ``row``, ``col``,
            ``data`` and ``shape``.
        dtype: dtype of the returned tensor's values. Defaults to
            ``torch.float32``, which matches the original behaviour.

    Returns:
        A sparse COO ``torch.Tensor`` with the same shape as ``matrix``. The
        result is not marked coalesced; call ``.coalesce()`` if required.
    """
    coo = matrix.tocoo()
    # torch's sparse COO layout requires int64 indices; scipy typically uses int32.
    indices = th.stack((
        th.from_numpy(coo.row).long(),
        th.from_numpy(coo.col).long(),
    ))
    # NOTE: from_numpy does not copy, and `.to` is a no-op when coo.data
    # already has `dtype`. In that case the values may share memory with the
    # SciPy matrix.
    values = th.from_numpy(coo.data).to(dtype)
    # th.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported equivalent.
    return th.sparse_coo_tensor(indices, values, coo.shape)