permute.py 397 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch


def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
    """Randomly reorder edges so that they are grouped by source node.

    Edges are first reordered by ``edge_rid`` and then sorted by the rank
    ``node_rid[row]`` of their source node. The returned edges are therefore
    contiguous per ``row`` value, with groups ordered by ``node_rid``. Node ids
    themselves are not relabelled; only the order of the edges changes.

    Args:
        edge_index: Two-row index tensor ``[row, col]`` (presumably of shape
            ``[2, num_edges]`` -- confirm against callers).
        num_nodes: Number of nodes; only used to draw ``node_rid`` when it is
            not given.
        node_rid: Optional per-node sort key, indexed by node id. Defaults to
            a random permutation of ``range(num_nodes)``.
        edge_rid: Optional index tensor used to reorder the edges before the
            sort. Defaults to a random permutation of the edges.

    Returns:
        Tuple ``(row, col)`` of the reordered edge endpoints.
    """
    row, col = edge_index
    device = row.device

    # Pre-shuffle the edges so the order inside each source-node group is
    # randomised before grouping.
    if edge_rid is None:
        # Draw on the edges' device to avoid a host->device index transfer.
        edge_rid = torch.randperm(row.size(0), device=device)
    row, col = row[edge_rid], col[edge_rid]

    if node_rid is None:
        # Must live on the same device as ``row``: indexing a CPU tensor with
        # CUDA indices (``node_rid[row]`` below) raises an error.
        node_rid = torch.randperm(num_nodes, device=device)
    # NOTE: torch.sort is not stable by default, so the exact order of edges
    # within one source-node group may depend on the sort implementation.
    _, perm = node_rid[row].sort()
    row, col = row[perm], col[perm]

    return row, col