utils.py 1.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from dgl.mock_sparse2 import (
    create_from_coo,
    create_from_csc,
    create_from_csr,
    SparseMatrix,
)


def clone_detach_and_grad(t):
    """Return an independent copy of ``t`` that is a fresh autograd leaf.

    The copy shares no autograd history with ``t`` (it is detached) and has
    ``requires_grad`` enabled, so gradients computed through it accumulate in
    the copy's own ``.grad`` rather than flowing back into ``t``.
    """
    leaf = t.detach().clone()
    # requires_grad_() works in place and returns the same tensor object.
    return leaf.requires_grad_()


def rand_coo(shape, nnz, dev):
    """Build a random COO ``SparseMatrix`` with ``nnz`` stored entries.

    Row indices are drawn uniformly from ``[0, shape[0])`` and column indices
    from ``[0, shape[1])``, independently. Duplicate (row, col) pairs are
    therefore possible. Values are standard-normal and form a leaf tensor
    created with ``requires_grad=True``. All tensors live on ``dev``.
    """
    sample_size = (nnz,)
    row_idx = torch.randint(0, shape[0], sample_size, device=dev)
    col_idx = torch.randint(0, shape[1], sample_size, device=dev)
    values = torch.randn(nnz, device=dev, requires_grad=True)
    return create_from_coo(row_idx, col_idx, values, shape)


def rand_csr(shape, nnz, dev):
    """Build a random CSR ``SparseMatrix`` with ``nnz`` stored entries.

    Args:
        shape: Matrix shape; row counts are sampled over ``shape[0]`` rows and
            column indices over ``shape[1]`` columns.
        nnz: Number of stored entries (duplicate positions are possible).
        dev: Device on which all index and value tensors are allocated.

    Returns:
        The matrix produced by ``create_from_csr``. Its values are a
        standard-normal leaf tensor created with ``requires_grad=True``.
    """
    row = torch.randint(0, shape[0], (nnz,), device=dev)
    col = torch.randint(0, shape[1], (nnz,), device=dev)
    val = torch.randn(nnz, device=dev, requires_grad=True)
    # Row pointer: indptr[i + 1] - indptr[i] is the number of sampled entries
    # in row i. bincount tallies all rows in one op, instead of a Python loop
    # issuing one tiny tensor update (and a host sync on GPU) per entry.
    indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
    indptr[1:] = torch.cumsum(torch.bincount(row, minlength=shape[0]), 0)
    # NOTE: ``col`` is used in sampling order, not regrouped by ``row``. Row
    # and column indices are drawn independently, so each stored entry still
    # gets a uniformly random column index.
    return create_from_csr(indptr, col, val, shape=shape)


def rand_csc(shape, nnz, dev):
    """Build a random CSC ``SparseMatrix`` with ``nnz`` stored entries.

    Args:
        shape: Matrix shape; row indices are sampled over ``shape[0]`` rows and
            column counts over ``shape[1]`` columns.
        nnz: Number of stored entries (duplicate positions are possible).
        dev: Device on which all index and value tensors are allocated.

    Returns:
        The matrix produced by ``create_from_csc``. Its values are a
        standard-normal leaf tensor created with ``requires_grad=True``.
    """
    row = torch.randint(0, shape[0], (nnz,), device=dev)
    col = torch.randint(0, shape[1], (nnz,), device=dev)
    val = torch.randn(nnz, device=dev, requires_grad=True)
    # Column pointer: indptr[j + 1] - indptr[j] is the number of sampled
    # entries in column j. bincount tallies all columns in one op, instead of
    # a Python loop issuing one tiny tensor update (and a host sync on GPU)
    # per entry.
    indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
    indptr[1:] = torch.cumsum(torch.bincount(col, minlength=shape[1]), 0)
    # NOTE: ``row`` is used in sampling order, not regrouped by ``col``. Row
    # and column indices are drawn independently, so each stored entry still
    # gets a uniformly random row index.
    return create_from_csc(indptr, row, val, shape=shape)


def sparse_matrix_to_dense(A: SparseMatrix):
    """Return ``A`` as a dense tensor that is an independent autograd leaf.

    The dense result of ``A.dense()`` is cloned and detached before
    ``requires_grad`` is enabled. Without the detach, a matrix whose ``val``
    requires grad yields a non-leaf tensor. Backpropagating through that
    tensor would leave its ``.grad`` as ``None`` and instead accumulate
    gradients into ``A.val.grad``. This mirrors how
    ``sparse_matrix_to_torch_sparse`` detaches ``A.val``.
    """
    return clone_detach_and_grad(A.dense())


def sparse_matrix_to_torch_sparse(A: SparseMatrix):
    """Convert ``A`` into a coalesced ``torch.sparse_coo_tensor`` leaf.

    The values are a detached copy of ``A.val``, so the result has no autograd
    link back to ``A``. When ``A.val`` has more than one dimension, its last
    dimension is appended to the shape as a dense (hybrid) dimension. The
    returned tensor has ``requires_grad`` enabled.
    """
    row, col = A.coo()
    # (2, nnz) index matrix: first row holds row ids, second holds column ids.
    indices = torch.stack((row, col))
    values = A.val.clone().detach()
    size = A.shape
    if A.val.dim() > 1:
        size += (A.val.shape[-1],)
    result = torch.sparse_coo_tensor(indices, values, size).coalesce()
    return result.requires_grad_()