// utils.h: helpers for edge lists stored as parallel (row, col[, weight])
// tensors.
#pragma once

#include <torch/extension.h>

// Drops every self loop, i.e. every position i where row[i] == col[i].
//
// `row` and `col` are expected to have matching shapes (one entry per edge).
// The same boolean mask is applied to both, so the surviving (row, col) pairs
// stay aligned; masked_select returns them as 1-D tensors in their original
// order.
//
// Declared `inline` because this is a non-template definition in a header:
// without it, including the header from several translation units produces
// multiple-definition link errors.
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col) {
  auto mask = row != col;
  return std::make_tuple(row.masked_select(mask), col.masked_select(mask));
}

// Weighted variant: drops every position i where row[i] == col[i] from row,
// col and weight alike.
//
// The same mask filters all three tensors, so the kept entries remain aligned.
// The results are 1-D tensors in their original order.
// NOTE(review): assumes `weight` holds one scalar per edge (same shape as
// `row`); a multi-dimensional weight would be flattened by masked_select.
//
// `inline` avoids ODR/multiple-definition errors for this header definition.
inline std::tuple<at::Tensor, at::Tensor, at::Tensor>
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
  auto mask = row != col;
  return std::make_tuple(row.masked_select(mask), col.masked_select(mask),
                         weight.masked_select(mask));
}

// Shuffles an edge list by applying one random permutation to both `row` and
// `col`, so each (row[i], col[i]) pair stays together.
//
// NOTE(review): this name overloads ::rand from <cstdlib> (with a different
// signature); it is kept unchanged for caller compatibility.
//
// `inline` avoids ODR/multiple-definition errors for this header definition.
inline std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
  // The permutation is an index for index_select, so request an int64 index
  // explicitly rather than inheriting whatever dtype `row` has. Device and the
  // other options still follow `row`.
  auto perm = at::randperm(row.size(0), row.options().dtype(at::kLong));
  return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm));
}

// Reorders the edge list by ascending `row`, permuting `col` identically.
//
// Returns the sorted `row` and the correspondingly reordered `col`.
// NOTE(review): Tensor::sort() is not guaranteed to be stable, so the relative
// order of edges that share the same row value is unspecified.
//
// `inline` avoids ODR/multiple-definition errors for this header definition.
inline std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row,
                                                      at::Tensor col) {
  at::Tensor perm;
  std::tie(row, perm) = row.sort();
  return std::make_tuple(row, col.index_select(0, perm));
}

// Weighted variant: reorders the edge list by ascending `row`, applying the
// same permutation to `col` and `weight` so all three stay aligned.
// NOTE(review): Tensor::sort() is not guaranteed to be stable; ties in `row`
// come out in unspecified order.
//
// `inline` avoids ODR/multiple-definition errors for this header definition.
inline std::tuple<at::Tensor, at::Tensor, at::Tensor>
sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
  at::Tensor perm;
  std::tie(row, perm) = row.sort();
  return std::make_tuple(row, col.index_select(0, perm),
                         weight.index_select(0, perm));
}

// Counts how many times each node index occurs in `row`.
//
// Returns a tensor of length `num_nodes` with the dtype and device of `row`,
// where entry v is the number of positions i with row[i] == v.
// NOTE(review): every value in `row` must lie in [0, num_nodes); scatter_add_
// rejects out-of-range indices.
//
// `inline` avoids ODR/multiple-definition errors for this header definition.
inline at::Tensor degree(at::Tensor row, int64_t num_nodes) {
  // Call the factories through at:: explicitly instead of relying on
  // argument-dependent lookup to find unqualified zeros()/ones().
  auto zero = at::zeros({num_nodes}, row.options());
  auto one = at::ones({row.size(0)}, row.options());
  return zero.scatter_add_(0, row, one);
}

// Converts an edge list (row, col) into compressed-sparse-row form.
//
// Returns (rowptr, col). `rowptr` has num_nodes + 1 entries: it starts at 0,
// and rowptr[i] is the offset in `col` where node i's edges begin. `col` is
// reordered to match ascending row order.
//
// `inline` avoids ODR/multiple-definition errors for this header definition.
inline std::tuple<at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = sort_by_row(row, col);
  row = degree(row, num_nodes).cumsum(0);
  row = at::cat({at::zeros({1}, row.options()), row}, 0); // Prepend zero.
  return std::make_tuple(row, col);
}

// Weighted variant: converts (row, col, weight) into compressed-sparse-row
// form.
//
// Returns (rowptr, col, weight). `rowptr` has num_nodes + 1 entries and starts
// at 0; rowptr[i] is the offset where node i's edges begin. `col` and `weight`
// are permuted together into ascending row order.
//
// `inline` avoids ODR/multiple-definition errors for this header definition.
inline std::tuple<at::Tensor, at::Tensor, at::Tensor>
to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
  std::tie(row, col, weight) = sort_by_row(row, col, weight);
  row = degree(row, num_nodes).cumsum(0);
  row = at::cat({at::zeros({1}, row.options()), row}, 0); // Prepend zero.
  return std::make_tuple(row, col, weight);
}