import torch


def random_permute(edge_index, num_nodes):
    """Randomly reorder the columns of ``edge_index`` and keep ``num_nodes`` of them.

    Args:
        edge_index: Something that unpacks into exactly two tensors
            (source entries and target entries) indexable along dim 0,
            e.g. a tensor with two rows.
        num_nodes: Length of the second random permutation that is drawn;
            this is also the number of columns in the returned tensor.

    Returns:
        The selected source/target entries stacked along a new leading
        dimension (dim 0); each output column is one of the input columns,
        with source and target kept paired.

    NOTE(review): the result is a full permutation of the input columns
    only when ``num_nodes`` equals the number of edges. A smaller
    ``num_nodes`` yields just ``num_nodes`` of the shuffled columns, and a
    larger one makes the gather below index out of range. Confirm that
    this is the intended contract.
    """
    src, dst = edge_index

    # First random draw: a shuffle of the edge positions, applied to both
    # endpoints so every (source, target) pair stays together.
    edge_shuffle = torch.randperm(src.size(0))
    src = src[edge_shuffle]
    dst = dst[edge_shuffle]

    # Second random draw: gather entries of the shuffle at the positions
    # given by a permutation of range(num_nodes), then argsort those keys.
    # The argsort has num_nodes entries, each a position in [0, num_nodes).
    keys = edge_shuffle[torch.randperm(num_nodes)]
    order = torch.sort(keys).indices

    return torch.stack((src[order], dst[order]), dim=0)