permute.py 389 Bytes
Newer Older
1
2
3
import torch


rusty1s's avatar
rusty1s committed
4
def permute(edge_index, num_nodes, rid=None, perm_edges=True):
    """Reorder the edges of ``edge_index`` by a (random) ordering of nodes.

    Node labels are NOT changed; only the order of the edges is. Edges are
    sorted by ``rid[row]``, so edges that share a source node end up adjacent,
    and those groups appear in the order given by ``rid``.

    Args:
        edge_index (torch.Tensor): Index tensor unpacked as ``row, col``;
            presumably of shape ``(2, num_edges)`` and integer dtype
            (TODO confirm against callers).
        num_nodes (int): Number of nodes. Only used to draw ``rid`` when it
            is not supplied.
        rid (torch.Tensor, optional): Per-node sort key, indexed by ``row``.
            Defaults to ``torch.randperm(num_nodes)``. It is moved to the
            device of ``edge_index`` if needed.
        perm_edges (bool): If ``True``, shuffle the edges before sorting, so
            the order of edges within a source-node group follows the shuffle
            rather than the input order. ``sort()`` is not stable, so the
            within-group order is not strictly guaranteed either way.

    Returns:
        torch.Tensor: ``torch.stack([row, col], dim=0)`` of the reordered
        edges.
    """
    row, col = edge_index
    device = row.device

    if perm_edges:
        edge_rid = torch.randperm(row.size(0), device=device)
        row, col = row[edge_rid], col[edge_rid]

    # Keep the node key on edge_index's device: indexing a CPU tensor with
    # CUDA indices (``rid[row]``) raises a device-mismatch error.
    if rid is None:
        rid = torch.randperm(num_nodes, device=device)
    else:
        rid = rid.to(device)

    _, perm = rid[row].sort()
    row, col = row[perm], col[perm]

    return torch.stack([row, col], dim=0)