permute.py 730 Bytes
Newer Older
1
2
3
import torch


rusty1s's avatar
rusty1s committed
4
def sort(row, col):
    """Sort ``row`` ascending and reorder ``col`` with the same permutation.

    Args:
        row: Tensor of row indices (assumed 1-D -- TODO confirm); sorted
            along its last dimension by ``Tensor.sort``.
        col: Tensor indexed with the permutation that sorts ``row``, so its
            entries stay paired with the corresponding ``row`` entries.

    Returns:
        Tuple ``(row, col)`` where ``row`` is sorted and ``col`` has been
        reordered to match.
    """
    sorted_row, order = row.sort()
    return sorted_row, col[order]
rusty1s's avatar
tests  
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
11
def permute(row, col, num_nodes, node_rid=None, edge_rid=None):
    """Randomly reorder edges while keeping them grouped by source node.

    Edges are first shuffled by ``edge_rid``. Source nodes are then
    temporarily relabelled through ``node_rid``, and the edges are sorted by
    the new labels. Finally the labels are mapped back to the original node
    ids. The result lists edges grouped by ``row`` value, with both the
    order of the groups and the order inside each group randomized.

    Args:
        row: Tensor of source-node indices. Its first dimension is the
            number of edges (assumed 1-D -- TODO confirm). Values must be
            valid indices into ``node_rid``, i.e. ``< num_nodes``.
        col: Tensor of target-node indices, indexed in lockstep with ``row``.
        num_nodes: Number of nodes. It sizes the random ``node_rid`` when one
            is not given.
        node_rid: Optional node relabelling. It is assumed to be a
            permutation of ``range(num_nodes)``; the inverse mapping used at
            the end is only correct in that case.
        edge_rid: Optional index tensor used to reorder the edges before
            sorting.

    Returns:
        Tuple ``(row, col)`` with the reordered edges, ``row`` expressed in
        the original node ids.
    """
    num_edges = row.size(0)

    # Randomly reorder row and column indices. Build the permutation on
    # ``row``'s device directly rather than on the CPU followed by a copy.
    if edge_rid is None:
        edge_rid = torch.randperm(num_edges, device=row.device).to(row.dtype)
    row, col = row[edge_rid], col[edge_rid]

    # Randomly change row indices to new values.
    if node_rid is None:
        node_rid = torch.randperm(num_nodes, device=row.device).to(row.dtype)
    row = node_rid[row]

    # Sort row and column indices based on changed values.
    row, col = sort(row, col)

    # Revert previous row value changes to old indices: the argsort of a
    # permutation is its inverse, so it maps new labels back to old ids.
    row = node_rid.argsort()[row]

    return row, col