sample_cpu.cpp 4.27 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#include "sample_cpu.h"

#include "utils.h"

rusty1s's avatar
rusty1s committed
5
// Returns `rowptr`, `col`, `n_id`, `e_id`
rusty1s's avatar
rusty1s committed
6
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
7
8
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace) {
rusty1s's avatar
rusty1s committed
9
10
11
12
13
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(idx);
  CHECK_INPUT(idx.dim() == 1);

rusty1s's avatar
rusty1s committed
14
15
  srand(time(NULL) + 1000 * getpid()); // Initialize random seed.

rusty1s's avatar
rusty1s committed
16
17
18
19
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto idx_data = idx.data_ptr<int64_t>();

20
  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
rusty1s's avatar
rusty1s committed
21
22
23
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
  out_rowptr_data[0] = 0;

rusty1s's avatar
rusty1s committed
24
  std::vector<std::vector<std::tuple<int64_t, int64_t>>> cols; // col, e_id
rusty1s's avatar
rusty1s committed
25
26
27
28
  std::vector<int64_t> n_ids;
  std::unordered_map<int64_t, int64_t> n_id_map;

  int64_t i;
29
  for (int64_t n = 0; n < idx.numel(); n++) {
rusty1s's avatar
rusty1s committed
30
    i = idx_data[n];
rusty1s's avatar
rusty1s committed
31
    cols.push_back(std::vector<std::tuple<int64_t, int64_t>>());
rusty1s's avatar
rusty1s committed
32
33
34
35
    n_id_map[i] = n;
    n_ids.push_back(i);
  }

36
37
  int64_t n, c, e, row_start, row_end, row_count;

rusty1s's avatar
rusty1s committed
38
39
  if (num_neighbors < 0) { // No sampling ======================================

40
41
42
43
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
44

45
46
      for (int64_t j = 0; j < row_count; j++) {
        e = row_start + j;
rusty1s's avatar
rusty1s committed
47
48
49
50
51
52
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
53
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
54
      }
55
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
56
57
58
59
    }
  }

  else if (replace) { // Sample with replacement ===============================
rusty1s's avatar
rusty1s committed
60

61
62
63
64
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
65
66

      for (int64_t j = 0; j < num_neighbors; j++) {
67
        e = row_start + rand() % row_count;
rusty1s's avatar
rusty1s committed
68
69
70
71
72
73
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
74
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
75
      }
76
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
77
78
79
80
    }

  } else { // Sample without replacement via Robert Floyd algorithm ============

81
82
83
84
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
85
86

      std::unordered_set<int64_t> perm;
87
88
89
90
91
92
93
94
      if (row_count <= num_neighbors) {
        for (int64_t j = 0; j < row_count; j++)
          perm.insert(j);
      } else { // See: https://www.nowherenearithaca.com/2013/05/
               //      robert-floyds-tiny-and-beautiful.html
        for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
          if (!perm.insert(rand() % j).second)
            perm.insert(j);
rusty1s's avatar
rusty1s committed
95
96
97
98
        }
      }

      for (const int64_t &p : perm) {
99
        e = row_start + p;
rusty1s's avatar
rusty1s committed
100
101
102
103
104
105
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
106
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
107
      }
108
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
109
110
111
    }
  }

rusty1s's avatar
rusty1s committed
112
113
  int64_t N = n_ids.size();
  auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();
rusty1s's avatar
rusty1s committed
114

rusty1s's avatar
rusty1s committed
115
116
  int64_t E = out_rowptr_data[idx.numel()];
  auto out_col = torch::empty(E, col.options());
117
  auto out_col_data = out_col.data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
118
119
  auto out_e_id = torch::empty(E, col.options());
  auto out_e_id_data = out_e_id.data_ptr<int64_t>();
120
121

  i = 0;
rusty1s's avatar
rusty1s committed
122
123
124
125
126
127
128
129
130
  for (std::vector<std::tuple<int64_t, int64_t>> &col_vec : cols) {
    std::sort(col_vec.begin(), col_vec.end(),
              [](const std::tuple<int64_t, int64_t> &a,
                 const std::tuple<int64_t, int64_t> &b) -> bool {
                return std::get<0>(a) < std::get<0>(b);
              });
    for (const std::tuple<int64_t, int64_t> &value : col_vec) {
      out_col_data[i] = std::get<0>(value);
      out_e_id_data[i] = std::get<1>(value);
131
132
133
134
      i += 1;
    }
  }

rusty1s's avatar
rusty1s committed
135
  return std::make_tuple(out_rowptr, out_col, out_n_id, out_e_id);
rusty1s's avatar
rusty1s committed
136
}