"docs/vscode:/vscode.git/clone" did not exist on "b500df11559265857d6b51685affdc13822f625f"
sample_cpu.cpp 4.21 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#include "sample_cpu.h"

#include "utils.h"

rusty1s's avatar
rusty1s committed
5
// Returns `rowptr`, `col`, `n_id`, `e_id`
rusty1s's avatar
rusty1s committed
6
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
7
8
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace) {
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(idx);
  CHECK_INPUT(idx.dim() == 1);

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto idx_data = idx.data_ptr<int64_t>();

18
  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
rusty1s's avatar
rusty1s committed
19
20
21
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
  out_rowptr_data[0] = 0;

rusty1s's avatar
rusty1s committed
22
  std::vector<std::vector<std::tuple<int64_t, int64_t>>> cols; // col, e_id
rusty1s's avatar
rusty1s committed
23
24
25
26
  std::vector<int64_t> n_ids;
  std::unordered_map<int64_t, int64_t> n_id_map;

  int64_t i;
27
  for (int64_t n = 0; n < idx.numel(); n++) {
rusty1s's avatar
rusty1s committed
28
    i = idx_data[n];
rusty1s's avatar
rusty1s committed
29
    cols.push_back(std::vector<std::tuple<int64_t, int64_t>>());
rusty1s's avatar
rusty1s committed
30
31
32
33
    n_id_map[i] = n;
    n_ids.push_back(i);
  }

34
35
  int64_t n, c, e, row_start, row_end, row_count;

rusty1s's avatar
rusty1s committed
36
37
  if (num_neighbors < 0) { // No sampling ======================================

38
39
40
41
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
42

43
44
      for (int64_t j = 0; j < row_count; j++) {
        e = row_start + j;
rusty1s's avatar
rusty1s committed
45
46
47
48
49
50
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
51
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
52
      }
53
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
54
55
56
57
    }
  }

  else if (replace) { // Sample with replacement ===============================
rusty1s's avatar
rusty1s committed
58

59
60
61
62
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
63
64

      for (int64_t j = 0; j < num_neighbors; j++) {
65
        e = row_start + rand() % row_count;
rusty1s's avatar
rusty1s committed
66
67
68
69
70
71
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
72
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
73
      }
74
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
75
76
77
78
    }

  } else { // Sample without replacement via Robert Floyd algorithm ============

79
80
81
82
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
83
84

      std::unordered_set<int64_t> perm;
85
86
87
88
89
90
91
92
      if (row_count <= num_neighbors) {
        for (int64_t j = 0; j < row_count; j++)
          perm.insert(j);
      } else { // See: https://www.nowherenearithaca.com/2013/05/
               //      robert-floyds-tiny-and-beautiful.html
        for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
          if (!perm.insert(rand() % j).second)
            perm.insert(j);
rusty1s's avatar
rusty1s committed
93
94
95
96
        }
      }

      for (const int64_t &p : perm) {
97
        e = row_start + p;
rusty1s's avatar
rusty1s committed
98
99
100
101
102
103
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
104
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
105
      }
106
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
107
108
109
    }
  }

rusty1s's avatar
rusty1s committed
110
111
  int64_t N = n_ids.size();
  auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();
rusty1s's avatar
rusty1s committed
112

rusty1s's avatar
rusty1s committed
113
114
  int64_t E = out_rowptr_data[idx.numel()];
  auto out_col = torch::empty(E, col.options());
115
  auto out_col_data = out_col.data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
116
117
  auto out_e_id = torch::empty(E, col.options());
  auto out_e_id_data = out_e_id.data_ptr<int64_t>();
118
119

  i = 0;
rusty1s's avatar
rusty1s committed
120
121
122
123
124
125
126
127
128
  for (std::vector<std::tuple<int64_t, int64_t>> &col_vec : cols) {
    std::sort(col_vec.begin(), col_vec.end(),
              [](const std::tuple<int64_t, int64_t> &a,
                 const std::tuple<int64_t, int64_t> &b) -> bool {
                return std::get<0>(a) < std::get<0>(b);
              });
    for (const std::tuple<int64_t, int64_t> &value : col_vec) {
      out_col_data[i] = std::get<0>(value);
      out_e_id_data[i] = std::get<1>(value);
129
130
131
132
      i += 1;
    }
  }

rusty1s's avatar
rusty1s committed
133
  return std::make_tuple(out_rowptr, out_col, out_n_id, out_e_id);
rusty1s's avatar
rusty1s committed
134
}