perm.cpp 726 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#ifndef PERM_CPP
#define PERM_CPP

#include <torch/torch.h>

/// Randomly reorders an edge list so that edges sharing the same source node
/// end up contiguous, while the order of those source-node groups is random.
///
/// Two permutations are drawn, in this order: one over the edges and one over
/// the node ids. Both are created as CPU kLong tensors, so `row` and `col` are
/// presumably int64 CPU tensors. NOTE(review): confirm against callers.
///
/// @param row       Source index of each edge. Every value must lie in
///                  [0, num_nodes); it is used to index the node permutation.
/// @param col       Target index of each edge, aligned element-wise with `row`.
/// @param num_nodes Size of the random node relabeling.
/// @return (row, col) after reordering. `row` holds the original node ids
///         again; only the ordering of the edges has changed.
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes) {
  // Shuffle the edges themselves, keeping row/col pairs aligned.
  auto edge_perm = at::randperm(torch::CPU(at::kLong), row.size(0));
  auto shuffled_row = row.index_select(0, edge_perm);
  auto shuffled_col = col.index_select(0, edge_perm);

  // Draw a random relabeling (old id -> new id) and express rows in it.
  auto relabel = at::randperm(torch::CPU(at::kLong), num_nodes);
  auto relabeled_row = relabel.index_select(0, shuffled_row);

  // Sort by relabeled source so each node's edges become contiguous; the
  // sort indices carry the same reordering over to the column tensor.
  at::Tensor sorted_row;
  at::Tensor order;
  std::tie(sorted_row, order) = relabeled_row.sort();
  auto sorted_col = shuffled_col.index_select(0, order);

  // The argsort of a permutation is its inverse (new id -> old id), which
  // restores the original node ids in the row tensor.
  auto inverse = std::get<1>(relabel.sort());
  auto restored_row = inverse.index_select(0, sorted_row);

  return std::make_tuple(restored_row, sorted_col);
}

#endif // PERM_CPP