sample_cpu.cpp 3.83 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
#include "sample_cpu.h"

#include <cstdlib>
#include <set>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "utils.h"

// Returns `rowptr`, `col`, `n_id`, `e_id`,
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
7
8
sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
               int64_t num_neighbors, bool replace) {
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(idx);
  CHECK_INPUT(idx.dim() == 1);

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto idx_data = idx.data_ptr<int64_t>();

18
  auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
rusty1s's avatar
rusty1s committed
19
20
21
  auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
  out_rowptr_data[0] = 0;

22
  std::vector<std::multiset<int64_t>> cols;
rusty1s's avatar
rusty1s committed
23
24
25
26
27
  std::vector<int64_t> n_ids;
  std::unordered_map<int64_t, int64_t> n_id_map;
  std::vector<int64_t> e_ids;

  int64_t i;
28
  for (int64_t n = 0; n < idx.numel(); n++) {
rusty1s's avatar
rusty1s committed
29
    i = idx_data[n];
30
    cols.push_back(std::multiset<int64_t>());
rusty1s's avatar
rusty1s committed
31
32
33
34
    n_id_map[i] = n;
    n_ids.push_back(i);
  }

35
36
  int64_t n, c, e, row_start, row_end, row_count;

rusty1s's avatar
rusty1s committed
37
38
  if (num_neighbors < 0) { // No sampling ======================================

39
40
41
42
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
43

44
45
      for (int64_t j = 0; j < row_count; j++) {
        e = row_start + j;
rusty1s's avatar
rusty1s committed
46
47
48
49
50
51
52
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }

53
        cols[i].insert(n_id_map[c]);
rusty1s's avatar
rusty1s committed
54
55
        e_ids.push_back(e);
      }
56
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
57
58
59
60
    }
  }

  else if (replace) { // Sample with replacement ===============================
rusty1s's avatar
rusty1s committed
61

62
63
64
65
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
66
67

      for (int64_t j = 0; j < num_neighbors; j++) {
68
        e = row_start + rand() % row_count;
rusty1s's avatar
rusty1s committed
69
70
71
72
73
74
75
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }

76
77
        cols[i].insert(n_id_map[c]);
        e_ids.push_back(c);
rusty1s's avatar
rusty1s committed
78
      }
79
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
80
81
82
83
    }

  } else { // Sample without replacement via Robert Floyd algorithm ============

84
85
86
87
    for (int64_t i = 0; i < idx.numel(); i++) {
      n = idx_data[i];
      row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
      row_count = row_end - row_start;
rusty1s's avatar
rusty1s committed
88
89

      std::unordered_set<int64_t> perm;
90
91
92
93
94
95
96
97
      if (row_count <= num_neighbors) {
        for (int64_t j = 0; j < row_count; j++)
          perm.insert(j);
      } else { // See: https://www.nowherenearithaca.com/2013/05/
               //      robert-floyds-tiny-and-beautiful.html
        for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
          if (!perm.insert(rand() % j).second)
            perm.insert(j);
rusty1s's avatar
rusty1s committed
98
99
100
101
        }
      }

      for (const int64_t &p : perm) {
102
        e = row_start + p;
rusty1s's avatar
rusty1s committed
103
104
105
106
107
108
109
        c = col_data[e];

        if (n_id_map.count(c) == 0) {
          n_id_map[c] = n_ids.size();
          n_ids.push_back(c);
        }

110
111
        cols[i].insert(n_id_map[c]);
        e_ids.push_back(c);
rusty1s's avatar
rusty1s committed
112
      }
113
      out_rowptr_data[i + 1] = out_rowptr_data[i] + cols[i].size();
rusty1s's avatar
rusty1s committed
114
115
116
    }
  }

117
  int64_t n_len = n_ids.size(), e_len = e_ids.size();
rusty1s's avatar
rusty1s committed
118
119
120
  auto n_id = torch::from_blob(n_ids.data(), {n_len}, col.options()).clone();
  auto e_id = torch::from_blob(e_ids.data(), {e_len}, col.options()).clone();

121
122
123
124
125
126
127
128
129
130
131
132
  auto out_col = torch::empty(e_len, col.options());
  auto out_col_data = out_col.data_ptr<int64_t>();

  i = 0;
  for (const std::multiset<int64_t> &col_set : cols) {
    for (const int64_t &c : col_set) {
      out_col_data[i] = c;
      i += 1;
    }
  }

  return std::make_tuple(out_rowptr, out_col, n_id, e_id);
rusty1s's avatar
rusty1s committed
133
}