perm.py 657 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
import torch


def randperm(row, col):
    """Randomly shuffle paired row and column indices.

    A single random permutation of length ``row.size(0)`` is applied to
    both tensors, so each ``(row[i], col[i])`` pair stays together.

    Args:
        row: Index tensor whose entries along dim 0 are permuted.
        col: Tensor indexed with the same permutation; assumed to have the
            same size as ``row`` along dim 0 (not checked here).

    Returns:
        Tuple ``(row, col)`` of the shuffled tensors.
    """
    # Build the permutation on the inputs' device so device-resident
    # tensors are not indexed with a host-side permutation (which forces
    # an extra host-to-device copy).
    edge_rid = torch.randperm(row.size(0), device=row.device)
    return row[edge_rid], col[edge_rid]


def sort_row(row, col):
    """Sort paired indices by ascending ``row`` value.

    ``col`` is reordered with the same permutation, so pairs stay aligned.
    Uses PyTorch's default (non-stable) sort, so the relative order of
    entries sharing a ``row`` value is not guaranteed. Assumes ``row`` is
    1-D (TODO confirm against callers).
    """
    sorted_row, order = torch.sort(row)
    return sorted_row, col[order]


def randperm_sort_row(row, col, num_nodes=None):
    """Group pairs by ``row`` while randomizing the order of the groups.

    Entries with equal ``row`` values end up contiguous, as with
    ``sort_row``. The order in which the groups appear is random: row ids
    are relabelled with a random permutation, the pairs are sorted by the
    new labels, and the labels are then mapped back to the original ids.

    Args:
        row: 1-D tensor of row ids, presumably in ``[0, num_nodes)``
            (not checked here).
        col: Tensor reordered alongside ``row``.
        num_nodes: Size of the row-id range. If ``None``, it is inferred
            as ``row.max() + 1`` (0 for an empty ``row``).

    Returns:
        Tuple ``(row, col)`` of the reordered tensors.
    """
    if num_nodes is None:
        num_nodes = int(row.max()) + 1 if row.numel() > 0 else 0

    # Create the permutation on row's device: PyTorch rejects indexing a
    # CPU tensor with CUDA indices, which ``node_rid[row]`` would otherwise
    # do for CUDA input.
    node_rid = torch.randperm(num_nodes, device=row.device)

    # Randomly relabel row ids, then sort (group) by the new labels.
    row, col = sort_row(node_rid[row], col)

    # argsort of a permutation is its inverse: it maps each new label back
    # to the original row id.
    row = node_rid.argsort()[row]

    return row, col