permute.py 289 Bytes
Newer Older
1
2
3
import torch


rusty1s's avatar
rusty1s committed
4
def random_permute(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Shuffle the edges, then group them by source node in random node order.

    Args:
        edge_index: Integer tensor that unpacks along its first dimension
            into two equally long index vectors ``(row, col)``, one entry
            per edge. Every entry of ``row`` must lie in ``[0, num_nodes)``
            because it is used to index a table of size ``num_nodes``.
        num_nodes: Number of nodes; length of the random node ordering.

    Returns:
        A tensor of shape ``(2, num_edges)`` built by stacking the reordered
        ``row`` and ``col`` vectors. Every input edge appears exactly once;
        edges sharing a source node are contiguous, the groups appear in a
        random node order, and the order inside each group is random.

    Raises:
        IndexError: If ``row`` contains a value outside ``[0, num_nodes)``.
    """
    row, col = edge_index

    # Shuffle the edges first so that the order *within* each source-node
    # group is random regardless of whether the later sort is stable.
    edge_perm = torch.randperm(row.size(0), device=row.device)
    row, col = row[edge_perm], col[edge_perm]

    # Assign every node a random rank and sort edges by the rank of their
    # source node. The rank table is indexed by node id (``row``), never by
    # edge position, so this works for any ratio of nodes to edges and keeps
    # all edges.
    node_rank = torch.randperm(num_nodes, device=row.device)
    _, perm = node_rank[row].sort()
    row, col = row[perm], col[perm]

    return torch.stack([row, col], dim=0)