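// utils.h
// Small tensor helpers for edge lists given as parallel `row` / `col` index
// tensors: self-loop removal, random shuffling, sorting by row, degree
// counting and CSR conversion.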
#pragma once

#include <torch/torch.h>

/// Removes self-loops from an edge list stored as two parallel index tensors.
///
/// A position is kept only where `row` and `col` differ element-wise. The same
/// mask is applied to both tensors, so the outputs stay aligned with each
/// other.
///
/// NOTE(review): presumably `row` and `col` are 1-D index tensors of equal
/// length -- confirm against callers.
///
/// @param row  First index tensor of the edge list.
/// @param col  Second index tensor of the edge list.
/// @return     `(row, col)` restricted to the positions where `row != col`.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col) {
  auto mask = row != col;
  return std::make_tuple(row.masked_select(mask), col.masked_select(mask));
}

/// Removes self-loops from a weighted edge list.
///
/// Builds one mask from `row != col` and applies it to `row`, `col` and
/// `weight` alike, so the three outputs stay aligned with each other.
///
/// NOTE(review): presumably all three tensors are 1-D with one entry per edge
/// -- confirm against callers.
///
/// @param row     First index tensor of the edge list.
/// @param col     Second index tensor of the edge list.
/// @param weight  Per-edge values filtered with the same mask.
/// @return        `(row, col, weight)` restricted to positions where
///                `row != col`.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline std::tuple<at::Tensor, at::Tensor, at::Tensor>
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
  auto mask = row != col;
  return std::make_tuple(row.masked_select(mask), col.masked_select(mask),
                         weight.masked_select(mask));
}

/// Creates a random permutation of `n` elements.
///
/// Allocates an uninitialized 64-bit integer tensor of length `n` on the CPU
/// and lets `at::randperm_out` fill it.
///
/// @param n  Number of elements in the permutation.
/// @return   The tensor filled by `at::randperm_out`.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline at::Tensor randperm(int64_t n) {
  auto out = at::empty(n, torch::CPU(at::kLong));
  at::randperm_out(out, n);
  return out;
}

/// Shuffles an edge list by applying one random permutation to both tensors.
///
/// A single permutation of length `row.size(0)` is drawn and used to index
/// both `row` and `col` along dimension 0, so the two outputs stay aligned
/// with each other.
///
/// NOTE(review): the permutation comes from `randperm`, which allocates it on
/// the CPU -- presumably the inputs are CPU tensors too; confirm against
/// callers.
///
/// @param row  First index tensor of the edge list.
/// @param col  Second index tensor of the edge list.
/// @return     `(row, col)` reordered by the same random permutation.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row,
                                               at::Tensor col) {
  auto perm = randperm(row.size(0));
  return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm));
}

/// Sorts an edge list by its `row` tensor.
///
/// `row.sort()` yields the sorted values together with the permutation that
/// produced them; that permutation is then applied to `col` along dimension 0
/// so both outputs stay aligned with each other.
///
/// @param row  First index tensor of the edge list; defines the sort order.
/// @param col  Second index tensor, reordered to follow `row`.
/// @return     `(sorted row, col reordered by the same permutation)`.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row,
                                                      at::Tensor col) {
  at::Tensor perm;
  // `row` is a by-value parameter, so overwriting it with the sorted values
  // does not affect the caller's tensor handle.
  std::tie(row, perm) = row.sort();
  return std::make_tuple(row, col.index_select(0, perm));
}

// Weighted overload of sort_by_row.
/// Sorts a weighted edge list by its `row` tensor.
///
/// `row.sort()` yields the sorted values together with the permutation that
/// produced them; that permutation is then applied to both `col` and `weight`
/// along dimension 0 so all three outputs stay aligned with each other.
///
/// @param row     First index tensor of the edge list; defines the sort order.
/// @param col     Second index tensor, reordered to follow `row`.
/// @param weight  Per-edge values, reordered to follow `row`.
/// @return        `(sorted row, reordered col, reordered weight)`.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline std::tuple<at::Tensor, at::Tensor, at::Tensor>
sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
  at::Tensor perm;
  // `row` is a by-value parameter, so overwriting it with the sorted values
  // does not affect the caller's tensor handle.
  std::tie(row, perm) = row.sort();
  return std::make_tuple(row, col.index_select(0, perm),
                         weight.index_select(0, perm));
}

/// Counts how often each index occurs in `row`.
///
/// Starts from a zero tensor with `num_nodes` entries and scatter-adds a one
/// for every element of `row` into the slot that element names. Both helper
/// tensors are created with `row.options()`, so the result shares `row`'s
/// dtype and device.
///
/// NOTE(review): presumably every value in `row` lies in `[0, num_nodes)`;
/// the scatter would otherwise index out of range -- confirm against callers.
///
/// @param row        Index tensor whose values are counted.
/// @param num_nodes  Length of the returned count tensor.
/// @return           Tensor of length `num_nodes` holding the occurrence count
///                   of each index.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  // Qualify the factory functions explicitly: the unqualified calls only
  // resolved through argument-dependent lookup on the type of `row.options()`,
  // which breaks as soon as that type does not live in namespace `at`.
  auto zero = at::zeros(num_nodes, row.options());
  auto one = at::ones(row.size(0), row.options());
  return zero.scatter_add_(0, row, one);
}

// CSR conversion (unweighted).
/// Converts an edge list into a CSR-style `(row pointer, column)` pair.
///
/// Steps:
///   1. Sort the edges by `row` (reordering `col` alongside).
///   2. Count the occurrences of each row index and take the running sum.
///   3. Prepend a single zero, giving a pointer tensor with `num_nodes + 1`
///      entries.
///
/// @param row        First index tensor of the edge list.
/// @param col        Second index tensor of the edge list.
/// @param num_nodes  Number of distinct row slots; see `degree`.
/// @return           `(row pointer of length num_nodes + 1, col sorted by row)`.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row,
                                                 at::Tensor col,
                                                 int64_t num_nodes) {
  std::tie(row, col) = sort_by_row(row, col);
  row = degree(row, num_nodes).cumsum(0);
  // Prepend zero. `at::zeros` is qualified explicitly so the call does not
  // depend on argument-dependent lookup through the type of `row.options()`.
  row = at::cat({at::zeros(1, row.options()), row}, 0);
  return std::make_tuple(row, col);
}

// CSR conversion (weighted).
/// Converts a weighted edge list into a CSR-style
/// `(row pointer, column, weight)` triple.
///
/// Steps:
///   1. Sort the edges by `row` (reordering `col` and `weight` alongside).
///   2. Count the occurrences of each row index and take the running sum.
///   3. Prepend a single zero, giving a pointer tensor with `num_nodes + 1`
///      entries.
///
/// @param row        First index tensor of the edge list.
/// @param col        Second index tensor of the edge list.
/// @param weight     Per-edge values, reordered to follow the sorted rows.
/// @param num_nodes  Number of distinct row slots; see `degree`.
/// @return           `(row pointer of length num_nodes + 1, col sorted by row,
///                   weight sorted by row)`.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including this header from more than one translation unit produces
/// duplicate-symbol link errors.
inline std::tuple<at::Tensor, at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
  std::tie(row, col, weight) = sort_by_row(row, col, weight);
  row = degree(row, num_nodes).cumsum(0);
  // Prepend zero. `at::zeros` is qualified explicitly so the call does not
  // depend on argument-dependent lookup through the type of `row.options()`.
  row = at::cat({at::zeros(1, row.options()), row}, 0);
  return std::make_tuple(row, col, weight);
}