// sample_cpu.cpp — CPU implementation of sample_adj (author: rusty1s)
#include "sample_cpu.h"

#include "utils.h"

rusty1s's avatar
rusty1s committed
5
6
7
8
#ifdef _WIN32
#include <process.h>
#endif

rusty1s's avatar
rusty1s committed
9
// Returns `rowptr`, `col`, `n_id`, `e_id`
rusty1s's avatar
rusty1s committed
10
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
11
12
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace) {
rusty1s's avatar
rusty1s committed
13
14
15
16
17
18
19
20
21
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(idx);
  CHECK_INPUT(idx.dim() == 1);

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto idx_data = idx.data_ptr<int64_t>();

22
  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
rusty1s's avatar
rusty1s committed
23
24
25
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
  out_rowptr_data[0] = 0;

rusty1s's avatar
rusty1s committed
26
  std::vector<std::vector<std::tuple<int64_t, int64_t>>> cols; // col, e_id
rusty1s's avatar
rusty1s committed
27
28
29
30
  std::vector<int64_t> n_ids;
  std::unordered_map<int64_t, int64_t> n_id_map;

  int64_t i;
31
  for (int64_t n = 0; n < idx.numel(); n++) {
rusty1s's avatar
rusty1s committed
32
    i = idx_data[n];
rusty1s's avatar
rusty1s committed
33
    cols.push_back(std::vector<std::tuple<int64_t, int64_t>>());
rusty1s's avatar
rusty1s committed
34
35
36
37
    n_id_map[i] = n;
    n_ids.push_back(i);
  }

38
39
  int64_t n, c, e, row_start, row_end, row_count;

rusty1s's avatar
rusty1s committed
40
41
  if (num_neighbors < 0) { // No sampling ======================================

42
43
44
45
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
46

47
48
      for (int64_t j = 0; j < row_count; j++) {
        e = row_start + j;
rusty1s's avatar
rusty1s committed
49
50
51
52
53
54
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
55
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
56
      }
57
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
58
59
60
61
    }
  }

  else if (replace) { // Sample with replacement ===============================
rusty1s's avatar
rusty1s committed
62

63
64
65
66
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
67

68
69
      if (row_count > 0) {
        for (int64_t j = 0; j < num_neighbors; j++) {
70
          e = row_start + uniform_randint(row_count);
71
72
73
74
75
76
77
          c = col_data[e];

          if (n_id_map.count(c) == 0) {
            n_id_map[c] = n_ids.size();
            n_ids.push_back(c);
          }
          cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
78
79
        }
      }
80
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
81
82
83
84
    }

  } else { // Sample without replacement via Robert Floyd algorithm ============

85
86
87
88
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
89
90

      std::unordered_set<int64_t> perm;
91
92
93
94
95
96
      if (row_count <= num_neighbors) {
        for (int64_t j = 0; j < row_count; j++)
          perm.insert(j);
      } else { // See: https://www.nowherenearithaca.com/2013/05/
               //      robert-floyds-tiny-and-beautiful.html
        for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
97
          if (!perm.insert(uniform_randint(j)).second)
98
            perm.insert(j);
rusty1s's avatar
rusty1s committed
99
100
101
102
        }
      }

      for (const int64_t &p : perm) {
103
        e = row_start + p;
rusty1s's avatar
rusty1s committed
104
105
106
107
108
109
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
110
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
111
      }
112
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
113
114
115
    }
  }

rusty1s's avatar
rusty1s committed
116
117
  int64_t N = n_ids.size();
  auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();
rusty1s's avatar
rusty1s committed
118

rusty1s's avatar
rusty1s committed
119
120
  int64_t E = out_rowptr_data[idx.numel()];
  auto out_col = torch::empty(E, col.options());
121
  auto out_col_data = out_col.data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
122
123
  auto out_e_id = torch::empty(E, col.options());
  auto out_e_id_data = out_e_id.data_ptr<int64_t>();
124
125

  i = 0;
rusty1s's avatar
rusty1s committed
126
127
128
129
130
131
132
133
134
  for (std::vector<std::tuple<int64_t, int64_t>> &col_vec : cols) {
    std::sort(col_vec.begin(), col_vec.end(),
              [](const std::tuple<int64_t, int64_t> &a,
                 const std::tuple<int64_t, int64_t> &b) -> bool {
                return std::get<0>(a) < std::get<0>(b);
              });
    for (const std::tuple<int64_t, int64_t> &value : col_vec) {
      out_col_data[i] = std::get<0>(value);
      out_e_id_data[i] = std::get<1>(value);
135
136
137
138
      i += 1;
    }
  }

rusty1s's avatar
rusty1s committed
139
  return std::make_tuple(out_rowptr, out_col, out_n_id, out_e_id);
rusty1s's avatar
rusty1s committed
140
}