permute.py 842 Bytes
Newer Older
1
2
3
import torch


rusty1s's avatar
tests  
rusty1s committed
4
def sort(edge_index):
    """Reorder an edge index so that its first row is in ascending order.

    Args:
        edge_index: a pair ``(row, col)`` of equally long 1-D tensors,
            typically a ``2 x E`` tensor (anything that unpacks into exactly
            two tensors works).

    Returns:
        A new ``2 x E`` tensor whose first row is ``row`` sorted ascending and
        whose second row is ``col`` permuted the same way, so every
        ``(row, col)`` pair stays together. The relative order of entries
        with equal ``row`` values is whatever ``torch.sort`` produces.
    """
    first, second = edge_index
    sorted_first, order = torch.sort(first)
    # Apply the same permutation to the second row to keep pairs aligned.
    return torch.stack([sorted_first, second[order]], dim=0)


def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
    """Randomly reorder edges so they are grouped by source node in random order.

    Args:
        edge_index: ``2 x E`` tensor of ``(row, col)`` node indices.
        num_nodes: number of nodes; used as the size of the random node
            permutation, so every ``row`` value must be below it.
        node_rid: optional node relabeling. It is expected to be a permutation
            of ``0 .. num_nodes - 1``, because its argsort is used as its
            inverse. When ``None``, a random one is drawn with
            ``torch.randperm``.
        edge_rid: optional column index used to shuffle the edges first. When
            ``None``, a random permutation of the ``E`` columns is drawn.

    Returns:
        A ``2 x E`` tensor holding the same ``(row, col)`` pairs (under
        ``edge_rid``). Edges sharing a source node are contiguous, and the
        groups appear in the order given by ``node_rid``.
    """
    num_edges = edge_index.size(1)

    # Step 1: shuffle the edges (columns). This random draw happens before the
    # node permutation is drawn, so seeded runs stay reproducible.
    if edge_rid is None:
        edge_rid = torch.randperm(num_edges).type_as(edge_index)
    src, dst = edge_index[:, edge_rid]

    # Step 2: relabel source nodes through the node permutation.
    if node_rid is None:
        node_rid = torch.randperm(num_nodes).type_as(edge_index)
    relabeled = node_rid[src]

    # Step 3: group edges by their relabeled source, keeping pairs aligned.
    relabeled, order = torch.sort(relabeled)
    dst = dst[order]

    # Step 4: undo the relabeling; argsort of a permutation is its inverse.
    inverse = torch.sort(node_rid)[1]
    return torch.stack([inverse[relabeled], dst], dim=0)