perm.cpp 704 Bytes
Newer Older
rusty1s's avatar
new try  
rusty1s committed
1
2
#ifndef PERM_INC
#define PERM_INC
rusty1s's avatar
rusty1s committed
3
4
5
6
7
8

#include <cstdint>
#include <tuple>

#include <torch/torch.h>

inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes) {
  // Randomly reorder row and column indices.
rusty1s's avatar
cleanup  
rusty1s committed
9
  auto perm = at::randperm(row.type(), row.size(0));
rusty1s's avatar
rusty1s committed
10
11
12
13
  row = row.index_select(0, perm);
  col = col.index_select(0, perm);

  // Randomly swap row values.
rusty1s's avatar
cleanup  
rusty1s committed
14
  auto node_rid = at::randperm(row.type(), num_nodes);
rusty1s's avatar
rusty1s committed
15
16
17
18
19
20
21
22
23
24
25
  row = node_rid.index_select(0, row);

  // Sort row and column indices row-wise.
  std::tie(row, perm) = row.sort();
  col = col.index_select(0, perm);

  // Revert row value swaps.
  row = std::get<1>(node_rid.sort()).index_select(0, row);

  return {row, col};
}
rusty1s's avatar
new try  
rusty1s committed
26
27

#endif // PERM_INC