// sample_cpu.cpp: CPU neighbor sampling over a CSR adjacency (sample_adj_cpu).
#include "sample_cpu.h"

#include <algorithm>
#include <random>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#ifdef _WIN32
#include <process.h>
#endif

#include "utils.h"

// Returns `rowptr`, `col`, `n_id`, `e_id`
rusty1s's avatar
rusty1s committed
10
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
11
12
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace) {
rusty1s's avatar
rusty1s committed
13
14
15
16
17
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(idx);
  CHECK_INPUT(idx.dim() == 1);

rusty1s's avatar
rusty1s committed
18
19
  srand(time(NULL) + 1000 * getpid()); // Initialize random seed.

rusty1s's avatar
rusty1s committed
20
21
22
23
  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto idx_data = idx.data_ptr<int64_t>();

24
  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
rusty1s's avatar
rusty1s committed
25
26
27
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
  out_rowptr_data[0] = 0;

rusty1s's avatar
rusty1s committed
28
  std::vector<std::vector<std::tuple<int64_t, int64_t>>> cols; // col, e_id
rusty1s's avatar
rusty1s committed
29
30
31
32
  std::vector<int64_t> n_ids;
  std::unordered_map<int64_t, int64_t> n_id_map;

  int64_t i;
33
  for (int64_t n = 0; n < idx.numel(); n++) {
rusty1s's avatar
rusty1s committed
34
    i = idx_data[n];
rusty1s's avatar
rusty1s committed
35
    cols.push_back(std::vector<std::tuple<int64_t, int64_t>>());
rusty1s's avatar
rusty1s committed
36
37
38
39
    n_id_map[i] = n;
    n_ids.push_back(i);
  }

40
41
  int64_t n, c, e, row_start, row_end, row_count;

rusty1s's avatar
rusty1s committed
42
43
  if (num_neighbors < 0) { // No sampling ======================================

44
45
46
47
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
48

49
50
      for (int64_t j = 0; j < row_count; j++) {
        e = row_start + j;
rusty1s's avatar
rusty1s committed
51
52
53
54
55
56
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
57
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
58
      }
59
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
60
61
62
63
    }
  }

  else if (replace) { // Sample with replacement ===============================
rusty1s's avatar
rusty1s committed
64

65
66
67
68
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
69

70
71
72
73
74
75
76
77
78
79
      if (row_count > 0) {
        for (int64_t j = 0; j < num_neighbors; j++) {
          e = row_start + rand() % row_count;
          c = col_data[e];

          if (n_id_map.count(c) == 0) {
            n_id_map[c] = n_ids.size();
            n_ids.push_back(c);
          }
          cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
80
81
        }
      }
82
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
83
84
85
86
    }

  } else { // Sample without replacement via Robert Floyd algorithm ============

87
88
89
90
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
91
92

      std::unordered_set<int64_t> perm;
93
94
95
96
97
98
99
100
      if (row_count <= num_neighbors) {
        for (int64_t j = 0; j < row_count; j++)
          perm.insert(j);
      } else { // See: https://www.nowherenearithaca.com/2013/05/
               //      robert-floyds-tiny-and-beautiful.html
        for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
          if (!perm.insert(rand() % j).second)
            perm.insert(j);
rusty1s's avatar
rusty1s committed
101
102
103
104
        }
      }

      for (const int64_t &p : perm) {
105
        e = row_start + p;
rusty1s's avatar
rusty1s committed
106
107
108
109
110
111
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }
rusty1s's avatar
rusty1s committed
112
        cols[i].push_back(std::make_tuple(n_id_map[c], e));
rusty1s's avatar
rusty1s committed
113
      }
114
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
115
116
117
    }
  }

rusty1s's avatar
rusty1s committed
118
119
  int64_t N = n_ids.size();
  auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();
rusty1s's avatar
rusty1s committed
120

rusty1s's avatar
rusty1s committed
121
122
  int64_t E = out_rowptr_data[idx.numel()];
  auto out_col = torch::empty(E, col.options());
123
  auto out_col_data = out_col.data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
124
125
  auto out_e_id = torch::empty(E, col.options());
  auto out_e_id_data = out_e_id.data_ptr<int64_t>();
126
127

  i = 0;
rusty1s's avatar
rusty1s committed
128
129
130
131
132
133
134
135
136
  for (std::vector<std::tuple<int64_t, int64_t>> &col_vec : cols) {
    std::sort(col_vec.begin(), col_vec.end(),
              [](const std::tuple<int64_t, int64_t> &a,
                 const std::tuple<int64_t, int64_t> &b) -> bool {
                return std::get<0>(a) < std::get<0>(b);
              });
    for (const std::tuple<int64_t, int64_t> &value : col_vec) {
      out_col_data[i] = std::get<0>(value);
      out_e_id_data[i] = std::get<1>(value);
137
138
139
140
      i += 1;
    }
  }

rusty1s's avatar
rusty1s committed
141
  return std::make_tuple(out_rowptr, out_col, out_n_id, out_e_id);
rusty1s's avatar
rusty1s committed
142
}