Commit b3746aab authored by rusty1s's avatar rusty1s
Browse files

sorting

parent 3307f0d9
......@@ -12,8 +12,17 @@ class SparseStorage(object):
assert row.device == row.device
assert row.dim() == 1 and col.dim() == 1 and row.numel() == col.numel()
if sparse_size is None:
sparse_size = Size((row.max().item() + 1, col.max().item() + 1))
if not is_sorted:
# Sort row and col
idx = sparse_size[1] * row + col
# Only sort if necessary...
if (idx <= torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
perm = idx.argsort()
row = row[perm]
col = col[perm]
value = None if value is None else value[perm]
rowptr = None
colptr = None
arg_csr_to_csc = None
......@@ -23,17 +32,16 @@ class SparseStorage(object):
assert row.device == value.device and value.size(0) == row.size(0)
value = value.contiguous()
if sparse_size is None:
sparse_size = Size((row[-1].item() + 1, col.max().item() + 1))
ones = None
if rowptr is None:
ones = torch.ones_like(row)
rowptr = segment_add(ones, row, dim=0, dim_size=sparse_size[0])
out_deg = segment_add(ones, row, dim=0, dim_size=sparse_size[0])
rowptr = torch.cat([row.new_zeros(1), out_deg.cumsum(0)], dim=0)
if colptr is None:
ones = torch.ones_like(col) if ones is None else ones
colptr = scatter_add(ones, col, dim=0, dim_size=sparse_size[1])
in_deg = scatter_add(ones, col, dim=0, dim_size=sparse_size[1])
colptr = torch.cat([col.new_zeros(1), in_deg.cumsum(0)], dim=0)
if arg_csr_to_csc is None:
idx = sparse_size[0] * col + row
......@@ -209,20 +217,16 @@ if __name__ == '__main__':
data = dataset[0].to(device)
edge_index = data.edge_index
row, col = edge_index
print(row.size())
print(row[:20])
print(col[:20])
print('--------')
# storage = SparseStorage(row, col)
idx = data.num_nodes * col + row
perm = idx.argsort()
row, col = row[perm], col[perm]
print(row[:20])
print(col[:20])
print('--------')
perm = perm.argsort()
row, col = row[perm], col[perm]
print(row[:20])
print(col[:20])
storage = SparseStorage(row, col)
# idx = data.num_nodes * col + row
# perm = idx.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])
# print('--------')
# perm = perm.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment