Commit b3746aab authored by rusty1s's avatar rusty1s
Browse files

sorting

parent 3307f0d9
...@@ -12,28 +12,36 @@ class SparseStorage(object): ...@@ -12,28 +12,36 @@ class SparseStorage(object):
assert row.device == row.device assert row.device == row.device
assert row.dim() == 1 and col.dim() == 1 and row.numel() == col.numel() assert row.dim() == 1 and col.dim() == 1 and row.numel() == col.numel()
if sparse_size is None:
sparse_size = Size((row.max().item() + 1, col.max().item() + 1))
if not is_sorted: if not is_sorted:
# Sort row and col idx = sparse_size[1] * row + col
rowptr = None # Only sort if necessary...
colptr = None if (idx <= torch.cat([idx.new_zeros(1), idx[:-1]], dim=0)).any():
arg_csr_to_csc = None perm = idx.argsort()
arg_csc_to_csr = None row = row[perm]
col = col[perm]
value = None if value is None else value[perm]
rowptr = None
colptr = None
arg_csr_to_csc = None
arg_csc_to_csr = None
if value is not None: if value is not None:
assert row.device == value.device and value.size(0) == row.size(0) assert row.device == value.device and value.size(0) == row.size(0)
value = value.contiguous() value = value.contiguous()
if sparse_size is None:
sparse_size = Size((row[-1].item() + 1, col.max().item() + 1))
ones = None ones = None
if rowptr is None: if rowptr is None:
ones = torch.ones_like(row) ones = torch.ones_like(row)
rowptr = segment_add(ones, row, dim=0, dim_size=sparse_size[0]) out_deg = segment_add(ones, row, dim=0, dim_size=sparse_size[0])
rowptr = torch.cat([row.new_zeros(1), out_deg.cumsum(0)], dim=0)
if colptr is None: if colptr is None:
ones = torch.ones_like(col) if ones is None else ones ones = torch.ones_like(col) if ones is None else ones
colptr = scatter_add(ones, col, dim=0, dim_size=sparse_size[1]) in_deg = scatter_add(ones, col, dim=0, dim_size=sparse_size[1])
colptr = torch.cat([col.new_zeros(1), in_deg.cumsum(0)], dim=0)
if arg_csr_to_csc is None: if arg_csr_to_csc is None:
idx = sparse_size[0] * col + row idx = sparse_size[0] * col + row
...@@ -209,20 +217,16 @@ if __name__ == '__main__': ...@@ -209,20 +217,16 @@ if __name__ == '__main__':
data = dataset[0].to(device) data = dataset[0].to(device)
edge_index = data.edge_index edge_index = data.edge_index
row, col = edge_index row, col = edge_index
print(row.size())
print(row[:20]) storage = SparseStorage(row, col)
print(col[:20]) # idx = data.num_nodes * col + row
print('--------') # perm = idx.argsort()
# row, col = row[perm], col[perm]
# storage = SparseStorage(row, col) # print(row[:20])
idx = data.num_nodes * col + row # print(col[:20])
perm = idx.argsort() # print('--------')
row, col = row[perm], col[perm]
print(row[:20]) # perm = perm.argsort()
print(col[:20]) # row, col = row[perm], col[perm]
print('--------') # print(row[:20])
# print(col[:20])
perm = perm.argsort()
row, col = row[perm], col[perm]
print(row[:20])
print(col[:20])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment