utils.py 5.44 KB
Newer Older
1
import numpy as np
2
import torch
3

czkkkkkk's avatar
czkkkkkk committed
4
from dgl.sparse import diag, from_csc, from_csr, SparseMatrix, spmatrix
5

6
7
8
# Fix both random number generators used by the helpers below (torch for
# values/coordinates, numpy for sampling and strides) so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

9
10
11
12
13
14
15

def clone_detach_and_grad(t):
    """Return a fresh leaf copy of ``t`` that requires gradients.

    The copy owns its own storage and carries no autograd history from ``t``.
    """
    leaf = t.detach().clone()
    leaf.requires_grad_(True)
    return leaf


16
17
18
19
20
21
22
23
24
25
def rand_stride(t):
    """Add stride to the last dimension of a tensor.

    Returns a detached tensor with the same values as ``t``. Its strides are
    those of a contiguous copy of ``t``, multiplied by a step of 2 or 3 drawn
    from ``np.random``, so the last dimension is non-contiguous.
    Floating-point results are marked as requiring grad.
    """
    step = np.random.randint(2, 4)
    # Lay out `step` copies side by side along a new trailing axis, then view
    # only the first copy: consecutive elements are now `step` apart in memory.
    widened = torch.stack([t for _ in range(step)], dim=-1)
    out = widened.select(-1, 0).detach()
    if torch.is_floating_point(t):
        out.requires_grad_()
    return out


26
def rand_coo(shape, nnz, dev, nz_dim=None):
    """Build a random COO sparse matrix with ``nnz`` distinct positions.

    Args:
        shape: ``(num_rows, num_cols)`` of the matrix.
        nnz: Number of non-zeros; positions are drawn without replacement,
            so it must not exceed ``shape[0] * shape[1]``.
        dev: Device for the created tensors.
        nz_dim: If given, each non-zero holds a vector of this length.

    Returns:
        The matrix produced by ``spmatrix`` from strided coordinate and value
        tensors (see ``rand_stride``).
    """
    num_rows, num_cols = shape[0], shape[1]
    # Sample flat positions without replacement so no (row, col) pair repeats.
    flat = np.random.choice(num_rows * num_cols, nnz, replace=False)
    flat = torch.tensor(flat, device=dev).long()
    rows = torch.div(flat, num_cols, rounding_mode="floor")
    cols = flat % num_cols

    val_shape = (nnz,) if nz_dim is None else (nnz, nz_dim)
    values = torch.randn(val_shape, device=dev, requires_grad=True)

    coords = rand_stride(torch.stack([rows, cols]))
    values = rand_stride(values)
    return spmatrix(coords, values, shape)
40
41


42
def rand_csr(shape, nnz, dev, nz_dim=None):
    """Create a random CSR sparse matrix with ``nnz`` distinct non-zeros.

    Args:
        shape: ``(num_rows, num_cols)`` of the matrix.
        nnz: Number of non-zero entries; positions are drawn without
            replacement, so it must not exceed ``shape[0] * shape[1]``.
        dev: Device on which the tensors are created.
        nz_dim: If given, each non-zero holds a vector of this length instead
            of a scalar.

    Returns:
        The matrix returned by ``from_csr``. Its ``indptr``, ``indices`` and
        ``val`` tensors have been given a non-unit stride by ``rand_stride``.
    """
    # Create a sparse matrix without duplicate entries.
    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
    nnzid = torch.tensor(nnzid, device=dev).long()
    row = torch.div(nnzid, shape[1], rounding_mode="floor")
    col = nnzid % shape[1]

    if nz_dim is None:
        val = torch.randn(nnz, device=dev, requires_grad=True)
    else:
        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)

    # indptr[r + 1] - indptr[r] is the number of entries in row r. Counting
    # with bincount avoids one tiny tensor update (kernel launch) per non-zero.
    indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
    indptr[1:] = torch.cumsum(torch.bincount(row, minlength=shape[0]), 0)

    # Group column indices by row. The values are i.i.d. random, so they do
    # not need to be permuted together with the columns.
    indices = col[torch.argsort(row)]

    indptr = rand_stride(indptr)
    indices = rand_stride(indices)
    val = rand_stride(val)
    return from_csr(indptr, indices, val, shape=shape)
62
63


64
def rand_csc(shape, nnz, dev, nz_dim=None):
    """Create a random CSC sparse matrix with ``nnz`` distinct non-zeros.

    Args:
        shape: ``(num_rows, num_cols)`` of the matrix.
        nnz: Number of non-zero entries; positions are drawn without
            replacement, so it must not exceed ``shape[0] * shape[1]``.
        dev: Device on which the tensors are created.
        nz_dim: If given, each non-zero holds a vector of this length instead
            of a scalar.

    Returns:
        The matrix returned by ``from_csc``. Its ``indptr``, ``indices`` and
        ``val`` tensors have been given a non-unit stride by ``rand_stride``.
    """
    # Create a sparse matrix without duplicate entries.
    nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
    nnzid = torch.tensor(nnzid, device=dev).long()
    row = torch.div(nnzid, shape[1], rounding_mode="floor")
    col = nnzid % shape[1]

    if nz_dim is None:
        val = torch.randn(nnz, device=dev, requires_grad=True)
    else:
        val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)

    # indptr[c + 1] - indptr[c] is the number of entries in column c. Counting
    # with bincount avoids one tiny tensor update (kernel launch) per non-zero.
    indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
    indptr[1:] = torch.cumsum(torch.bincount(col, minlength=shape[1]), 0)

    # Group row indices by column. The values are i.i.d. random, so they do
    # not need to be permuted together with the rows.
    indices = row[torch.argsort(col)]

    indptr = rand_stride(indptr)
    indices = rand_stride(indices)
    val = rand_stride(val)
    return from_csc(indptr, indices, val, shape=shape)
84
85


czkkkkkk's avatar
czkkkkkk committed
86
87
88
89
90
91
92
93
94
def rand_diag(shape, nnz, dev, nz_dim=None):
    """Build a random diagonal matrix of the given ``shape`` via ``diag``.

    ``nnz`` is ignored: the diagonal always has ``min(shape)`` entries.
    (Presumably the parameter exists so this can be called like the other
    ``rand_*`` helpers.) If ``nz_dim`` is given, each diagonal entry is a
    vector of that length.
    """
    n_diag = min(shape)
    size = (n_diag,) if nz_dim is None else (n_diag, nz_dim)
    return diag(torch.randn(size, device=dev, requires_grad=True), shape)


95
96
97
98
99
def rand_coo_uncoalesced(shape, nnz, dev):
    """Build a random COO sparse matrix whose positions may repeat.

    Row and column indices are drawn independently, so the result can contain
    duplicate ``(row, col)`` entries. Only the coordinate tensor is given a
    non-unit stride; the values are passed to ``spmatrix`` as created.
    """
    rows = torch.randint(shape[0], (nnz,), device=dev)
    cols = torch.randint(shape[1], (nnz,), device=dev)
    values = torch.randn(nnz, device=dev, requires_grad=True)
    coords = rand_stride(torch.stack([rows, cols]))
    return spmatrix(coords, values, shape)
103
104
105
106
107
108
109
110
111
112
113
114
115


def rand_csr_uncoalesced(shape, nnz, dev):
    """Create a random CSR sparse matrix that may contain duplicate entries.

    Args:
        shape: ``(num_rows, num_cols)`` of the matrix.
        nnz: Number of stored entries (duplicates included).
        dev: Device on which the tensors are created.

    Returns:
        The matrix returned by ``from_csr``. Its ``indptr``, ``indices`` and
        ``val`` tensors have been given a non-unit stride by ``rand_stride``.
    """
    # Create a sparse matrix with possible duplicate entries.
    row = torch.randint(shape[0], (nnz,), device=dev)
    col = torch.randint(shape[1], (nnz,), device=dev)
    val = torch.randn(nnz, device=dev, requires_grad=True)

    # Per-row entry counts via bincount instead of one tensor update per
    # entry; the prefix sum turns counts into row offsets.
    indptr = torch.zeros(shape[0] + 1, device=dev, dtype=torch.int64)
    indptr[1:] = torch.cumsum(torch.bincount(row, minlength=shape[0]), 0)

    # Group column indices by row. The values are i.i.d. random, so they do
    # not need to be permuted together with the columns.
    indices = col[torch.argsort(row)]

    indptr = rand_stride(indptr)
    indices = rand_stride(indices)
    val = rand_stride(val)
    return from_csr(indptr, indices, val, shape=shape)
120
121
122
123
124
125


def rand_csc_uncoalesced(shape, nnz, dev):
    """Create a random CSC sparse matrix that may contain duplicate entries.

    Args:
        shape: ``(num_rows, num_cols)`` of the matrix.
        nnz: Number of stored entries (duplicates included).
        dev: Device on which the tensors are created.

    Returns:
        The matrix returned by ``from_csc``. Its ``indptr``, ``indices`` and
        ``val`` tensors have been given a non-unit stride by ``rand_stride``.
    """
    # Create a sparse matrix with possible duplicate entries.
    row = torch.randint(shape[0], (nnz,), device=dev)
    col = torch.randint(shape[1], (nnz,), device=dev)
    val = torch.randn(nnz, device=dev, requires_grad=True)

    # Per-column entry counts via bincount instead of one tensor update per
    # entry; the prefix sum turns counts into column offsets.
    indptr = torch.zeros(shape[1] + 1, device=dev, dtype=torch.int64)
    indptr[1:] = torch.cumsum(torch.bincount(col, minlength=shape[1]), 0)

    # Group row indices by column. The values are i.i.d. random, so they do
    # not need to be permuted together with the rows.
    indices = row[torch.argsort(col)]

    indptr = rand_stride(indptr)
    indices = rand_stride(indices)
    val = rand_stride(val)
    return from_csc(indptr, indices, val, shape=shape)
137
138
139


def sparse_matrix_to_dense(A: SparseMatrix):
    """Densify ``A`` into a standalone leaf tensor that requires gradients."""
    return clone_detach_and_grad(A.to_dense())
142
143


144
def sparse_matrix_to_torch_sparse(A: "SparseMatrix", val=None):
    """Convert ``A`` to a coalesced ``torch.sparse_coo_tensor`` leaf.

    Args:
        A: Source sparse matrix. Its ``coo()`` coordinates and ``shape`` are
            used, and ``A.val`` supplies the values when ``val`` is None.
        val: Optional replacement values, one leading-dimension slot per
            non-zero of ``A`` (assumed to follow ``A.coo()`` order — TODO
            confirm against callers).

    Returns:
        A coalesced sparse COO tensor that requires grad. Any dimensions of
        the values beyond the first are appended to ``A.shape`` as dense
        (hybrid) dimensions.
    """
    row, col = A.coo()
    edge_index = torch.stack((row, col))
    if val is None:
        val = A.val
    # Copy so the result shares neither storage nor autograd history with the
    # input values.
    val = val.clone().detach()
    # Take the dense trailing dims from the values actually used (not from
    # A.val), so an explicit `val` with extra dims gets a matching size.
    shape = tuple(A.shape) + tuple(val.shape[1:])
    ret = torch.sparse_coo_tensor(edge_index, val, shape).coalesce()
    ret.requires_grad_()
    return ret
156
157
158
159
160
161
162
163


def dense_mask(dense, sparse):
    """Keep only the entries of ``dense`` at the non-zero positions of ``sparse``.

    Args:
        dense: Tensor indexed as ``dense[row, col]``. Any dimensions after the
            first two are kept whole for each selected position.
        sparse: Object whose ``coo()`` returns ``(row, col)`` index tensors.
            Duplicate coordinates are allowed.

    Returns:
        A tensor shaped like ``dense`` that equals ``dense`` at every
        ``(row, col)`` of ``sparse`` and is zero elsewhere.
    """
    row, col = sparse.coo()
    # One indexing op instead of one tensor copy per non-zero; duplicate
    # coordinates just set the same flag again.
    mask = torch.zeros(dense.shape[:2], dtype=torch.bool, device=dense.device)
    mask[row, col] = True
    # Broadcast the mask over any trailing value dimensions.
    mask = mask.reshape(tuple(mask.shape) + (1,) * (dense.dim() - 2))
    # torch.where (rather than multiplying by the mask) keeps NaN/inf outside
    # the mask out of the result and gives each selected entry gradient once.
    return torch.where(mask, dense, torch.zeros_like(dense))