permute.py 770 Bytes
Newer Older
1
2
3
import torch


rusty1s's avatar
rusty1s committed
4
def sort(row, col):
    """Sort ``row`` ascending and reorder ``col`` with the same permutation.

    Returns the sorted ``row`` together with ``col`` gathered by the sort
    indices, so that ``(row[i], col[i])`` pairs stay aligned.
    """
    sorted_row, order = torch.sort(row)
    return sorted_row, col[order]
rusty1s's avatar
tests  
rusty1s committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21


def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
    """Randomly shuffle an edge list while keeping it grouped by row.

    Edges are first reordered by ``edge_rid``. Row (source) indices are then
    temporarily relabelled with ``node_rid`` and the edges are sorted by
    those new labels. As a result, rows appear in a random node order and,
    within each row, columns keep the random order from the first step.
    Row labels are finally mapped back to the original node ids.

    Args:
        edge_index: index tensor whose first dimension holds ``[row, col]``.
            The second dimension has one entry per edge.
        num_nodes: number of nodes. Only used to draw ``node_rid`` when it is
            not given.
        node_rid: optional permutation of ``0..num_nodes-1`` used as the
            temporary row labels. It must be a permutation, because it is
            inverted to restore the original row ids.
        edge_rid: optional index tensor selecting/reordering edge positions.
            Defaults to a random permutation of all edges.

    Returns:
        ``torch.stack([row, col])``: a tensor of shape ``(2, len(edge_rid))``
        holding the permuted edges.
    """
    num_edges = edge_index.size(1)

    # Randomly reorder the edges. Drawing the permutation directly on
    # edge_index's device/dtype avoids a CPU -> GPU copy for CUDA inputs.
    if edge_rid is None:
        edge_rid = torch.randperm(
            num_edges, dtype=edge_index.dtype, device=edge_index.device)
    row, col = edge_index[:, edge_rid]

    # Temporarily relabel row indices with a random permutation of node ids.
    if node_rid is None:
        node_rid = torch.randperm(
            num_nodes, dtype=edge_index.dtype, device=edge_index.device)
    row = node_rid[row]

    # Group the edges by their relabelled row.
    row, col = sort(row, col)

    # Map relabelled rows back to the original node ids. `inv` is the
    # inverse permutation of node_rid (inv[node_rid[i]] == i). It is built
    # with an O(n) scatter instead of node_rid.sort()[1] (O(n log n)), which
    # gives the same result for a permutation. dtype is long, as before.
    num_labels = node_rid.numel()
    inv = torch.empty(num_labels, dtype=torch.long, device=node_rid.device)
    inv[node_rid] = torch.arange(num_labels, device=node_rid.device)
    row = inv[row]

    return torch.stack([row, col], dim=0)