hgt_sample_cpu.cpp 8.99 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include "hgt_sample_cpu.h"

rusty1s's avatar
rusty1s committed
3
#include "utils.h"
rusty1s's avatar
rusty1s committed
4

rusty1s's avatar
rusty1s committed
5
#define MAX_NEIGHBORS 50
rusty1s's avatar
update  
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
// Splits a relation string of the form "src__rel__dst" into its three
// components and returns them as an edge type triplet `(src, rel, dst)`.
//
// The first two fields end at the next "__" separator. The third field ends at
// a further "__" if one exists, otherwise it extends to the end of the string.
// If the input holds fewer than three fields, the missing ones are returned as
// empty strings.
edge_t split(const rel_t &rel_type) {
  const string delim = "__";
  vector<string> result(3);

  // Positions are kept as `size_t` so that `string::npos` can be tested
  // explicitly instead of relying on a narrowing conversion to `int`.
  size_t start = 0;
  for (int i = 0; i < 3; i++) {
    const size_t end = rel_type.find(delim, start);
    if (end == string::npos) {
      // No further separator: the remainder of the string is the last field.
      result[i] = rel_type.substr(start);
      break;
    }
    result[i] = rel_type.substr(start, end - start);
    start = end + delim.size();
  }

  return make_tuple(result[0], result[1], result[2]);
}

rusty1s's avatar
rusty1s committed
18
19
void update_budget_(
    unordered_map<node_t, unordered_map<int64_t, float>> *budget_dict,
rusty1s's avatar
rusty1s committed
20
    const node_t &node_type, //
rusty1s's avatar
rusty1s committed
21
22
23
24
    const vector<int64_t> &samples,
    const unordered_map<node_t, unordered_map<int64_t, int64_t>>
        &to_local_node_dict,
    const unordered_map<rel_t, edge_t> &to_edge_type,
rusty1s's avatar
update  
rusty1s committed
25
26
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict) {
rusty1s's avatar
rusty1s committed
27

rusty1s's avatar
rusty1s committed
28
29
30
  if (samples.empty())
    return;

rusty1s's avatar
update  
rusty1s committed
31
  for (const auto &kv : colptr_dict) {
rusty1s's avatar
rusty1s committed
32
    const auto &rel_type = kv.key();
rusty1s's avatar
rusty1s committed
33
34
35
    const auto &edge_type = to_edge_type.at(rel_type);
    const auto &src_node_type = get<0>(edge_type);
    const auto &dst_node_type = get<2>(edge_type);
rusty1s's avatar
rusty1s committed
36
37
38
39

    if (node_type != dst_node_type)
      continue;

rusty1s's avatar
rusty1s committed
40
    const auto &to_local_src_node = to_local_node_dict.at(src_node_type);
rusty1s's avatar
update  
rusty1s committed
41
42
    const auto *colptr_data = kv.value().data_ptr<int64_t>();
    const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
43
    auto &src_budget = budget_dict->at(src_node_type);
rusty1s's avatar
rusty1s committed
44

rusty1s's avatar
rusty1s committed
45
46
47
    for (const auto &w : samples) {
      const auto &col_start = colptr_data[w], &col_end = colptr_data[w + 1];
      if (col_end - col_start > MAX_NEIGHBORS) {
rusty1s's avatar
update  
rusty1s committed
48
        // There might be same neighbors with large neighborhood sizes.
rusty1s's avatar
rusty1s committed
49
50
51
52
53
54
        // In order to prevent that we fill our budget with many values of low
        // probability, we instead sample a fixed amount without replacement:
        auto indices = choice(col_end - col_start, MAX_NEIGHBORS, false);
        auto *indices_data = indices.data_ptr<int64_t>();
        for (int64_t i = 0; i < indices.numel(); i++) {
          const auto &v = row_data[col_start + indices_data[i]];
rusty1s's avatar
update  
rusty1s committed
55
          // Only add the neighbor in case we have not yet seen it before:
rusty1s's avatar
rusty1s committed
56
57
          if (to_local_src_node.find(v) == to_local_src_node.end())
            src_budget[v] += 1.f / float(MAX_NEIGHBORS);
rusty1s's avatar
update  
rusty1s committed
58
        }
rusty1s's avatar
rusty1s committed
59
60

      } else if (col_end != col_start) {
rusty1s's avatar
update  
rusty1s committed
61
        const auto inv_deg = 1.f / float(col_end - col_start);
rusty1s's avatar
rusty1s committed
62
63
        for (int64_t i = col_start; i < col_end; i++) {
          const auto &v = row_data[i];
rusty1s's avatar
update  
rusty1s committed
64
          // Only add the neighbor in case we have not yet seen it before:
rusty1s's avatar
rusty1s committed
65
66
          if (to_local_src_node.find(v) == to_local_src_node.end())
            src_budget[v] += inv_deg;
rusty1s's avatar
rusty1s committed
67
68
69
70
71
72
        }
      }
    }
  }
}

rusty1s's avatar
rusty1s committed
73
74
75
76
77
78
// Draws up to `num_samples` node ids from `budget` without replacement and
// returns them.
//
// Each candidate is weighted by the square of its budget value. The actual
// draw is delegated to `choice`, which returns positions into the candidate
// list; those positions are mapped back to node ids here.
vector<int64_t> sample_from(const unordered_map<int64_t, float> &budget,
                            const int64_t num_samples) {
  const auto num_candidates = budget.size();

  // Flatten the map into parallel arrays of node ids and sampling weights.
  vector<int64_t> candidates;
  vector<float> scores;
  candidates.reserve(num_candidates);
  scores.reserve(num_candidates);
  for (auto it = budget.begin(); it != budget.end(); ++it) {
    candidates.push_back(it->first);
    scores.push_back(it->second * it->second);
  }

  // NOTE(review): the `true` flag presumably makes `from_vector` wrap the
  // vector's memory instead of copying it -- confirm in utils.h. `scores`
  // stays alive for as long as the tensor is used either way.
  const auto score_tensor = from_vector(scores, true);
  const auto picked = choice(num_candidates, num_samples, false, score_tensor);
  const int64_t *picked_data = picked.data_ptr<int64_t>();
  const int64_t num_picked = picked.numel();

  vector<int64_t> out;
  out.reserve(num_picked);
  for (int64_t k = 0; k < num_picked; ++k)
    out.push_back(candidates[picked_data[k]]);
  return out;
}

rusty1s's avatar
rusty1s committed
95
96
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
rusty1s's avatar
update  
rusty1s committed
97
98
hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
               const c10::Dict<rel_t, torch::Tensor> &row_dict,
rusty1s's avatar
rusty1s committed
99
               const c10::Dict<node_t, torch::Tensor> &input_node_dict,
rusty1s's avatar
rusty1s committed
100
101
               const c10::Dict<node_t, vector<int64_t>> &num_samples_dict,
               const int64_t num_hops) {
rusty1s's avatar
rusty1s committed
102

rusty1s's avatar
rusty1s committed
103
104
  // Create a mapping to convert single string relations to edge type triplets:
  std::unordered_map<rel_t, edge_t> to_edge_type;
rusty1s's avatar
update  
rusty1s committed
105
  for (const auto &kv : colptr_dict) {
rusty1s's avatar
rusty1s committed
106
    const auto &rel_type = kv.key();
rusty1s's avatar
rusty1s committed
107
    to_edge_type[rel_type] = split(rel_type);
rusty1s's avatar
rusty1s committed
108
109
  }

rusty1s's avatar
rusty1s committed
110
111
112
113
  // Initialize some necessary data structures for the sampling process:
  unordered_map<node_t, vector<int64_t>> nodes_dict;
  unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
  unordered_map<node_t, unordered_map<int64_t, float>> budget_dict;
rusty1s's avatar
fix  
rusty1s committed
114
115
  for (const auto &kv : num_samples_dict) {
    const auto &node_type = kv.key();
rusty1s's avatar
rusty1s committed
116
117
    nodes_dict[node_type];
    to_local_node_dict[node_type];
rusty1s's avatar
fix  
rusty1s committed
118
119
    budget_dict[node_type];
  }
rusty1s's avatar
rusty1s committed
120

rusty1s's avatar
rusty1s committed
121
  // Add the input nodes to the sampled output nodes (line 1):
rusty1s's avatar
rusty1s committed
122
123
124
125
126
  for (const auto &kv : input_node_dict) {
    const auto &node_type = kv.key();
    const auto &input_node = kv.value();
    const auto *input_node_data = input_node.data_ptr<int64_t>();

rusty1s's avatar
rusty1s committed
127
128
    auto &nodes = nodes_dict.at(node_type);
    auto &to_local_node = to_local_node_dict.at(node_type);
rusty1s's avatar
rusty1s committed
129
    for (int64_t i = 0; i < input_node.numel(); i++) {
rusty1s's avatar
rusty1s committed
130
131
132
      const auto &v = input_node_data[i];
      nodes.push_back(v);
      to_local_node[v] = i;
rusty1s's avatar
rusty1s committed
133
    }
rusty1s's avatar
update  
rusty1s committed
134
  }
rusty1s's avatar
rusty1s committed
135
136
137
138
139
140
141
142
143
144
  b = steady_clock::now();
  std::cout << "3=" << duration_cast<microseconds>(b - a).count() << std::endl;

  a = steady_clock::now();
  // Update the budget based on the initial input set (line 3-5):
  for (const auto &kv : nodes_dict) {
    const auto &node_type = kv.first;
    const auto &last_samples = kv.second;
    update_budget_(&budget_dict, node_type, last_samples, to_local_node_dict,
                   to_edge_type, colptr_dict, row_dict);
rusty1s's avatar
rusty1s committed
145
146
147
  }

  for (int64_t ell = 0; ell < num_hops; ell++) {
rusty1s's avatar
rusty1s committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
    unordered_map<node_t, vector<int64_t>> samples_dict;
    for (auto &kv : budget_dict) {
      const auto &node_type = kv.first;
      auto &budget = kv.second;
      const auto num_samples = num_samples_dict.at(node_type)[ell];

      // Sample `num_samples` nodes, according to the budget (line 9-11):
      const auto samples = sample_from(budget, num_samples);
      samples_dict[node_type] = samples;

      // Add samples to the sampled output nodes, and erase them from the budget
      // (line 13/15):
      auto &nodes = nodes_dict.at(node_type);
      auto &to_local_node = to_local_node_dict.at(node_type);
      for (const auto &v : samples) {
        to_local_node[v] = nodes.size();
        nodes.push_back(v);
        budget.erase(v);
      }
rusty1s's avatar
update  
rusty1s committed
167
    }
rusty1s's avatar
rusty1s committed
168

rusty1s's avatar
rusty1s committed
169
170
171
172
173
174
175
176
    if (ell < num_hops - 1) {
      // Add neighbors of newly sampled nodes to the budget (line 14):
      // Note that we do not need to update the budget in the last iteration.
      for (const auto &kv : samples_dict) {
        const auto &node_type = kv.first;
        const auto &last_samples = kv.second;
        update_budget_(&budget_dict, node_type, last_samples,
                       to_local_node_dict, to_edge_type, colptr_dict, row_dict);
rusty1s's avatar
update  
rusty1s committed
177
      }
rusty1s's avatar
rusty1s committed
178
179
180
    }
  }

rusty1s's avatar
rusty1s committed
181
182
183
184
  c10::Dict<node_t, torch::Tensor> out_node_dict;
  c10::Dict<rel_t, torch::Tensor> out_row_dict;
  c10::Dict<rel_t, torch::Tensor> out_col_dict;
  c10::Dict<rel_t, torch::Tensor> out_edge_dict;
rusty1s's avatar
update  
rusty1s committed
185

rusty1s's avatar
rusty1s committed
186
  // Reconstruct the sampled adjacency matrix among the sampled nodes (line 19):
rusty1s's avatar
update  
rusty1s committed
187
  for (const auto &kv : colptr_dict) {
rusty1s's avatar
rusty1s committed
188
    const auto &rel_type = kv.key();
rusty1s's avatar
rusty1s committed
189
190
191
    const auto &edge_type = to_edge_type.at(rel_type);
    const auto &src_node_type = get<0>(edge_type);
    const auto &dst_node_type = get<2>(edge_type);
rusty1s's avatar
rusty1s committed
192

rusty1s's avatar
update  
rusty1s committed
193
194
    const auto *colptr_data = kv.value().data_ptr<int64_t>();
    const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
195

rusty1s's avatar
rusty1s committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
    const auto &dst_nodes = nodes_dict.at(dst_node_type);
    const auto &to_local_src_node = to_local_node_dict.at(src_node_type);

    vector<int64_t> rows, cols, edges;
    for (int64_t i = 0; i < (int64_t)dst_nodes.size(); i++) {
      const auto &w = dst_nodes[i];
      const auto &col_start = colptr_data[w], &col_end = colptr_data[w + 1];
      if (col_end - col_start > MAX_NEIGHBORS) {
        auto indices = choice(col_end - col_start, MAX_NEIGHBORS, false);
        auto *indices_data = indices.data_ptr<int64_t>();
        for (int64_t j = 0; j < indices.numel(); j++) {
          const auto &v = row_data[col_start + indices_data[j]];
          if (to_local_src_node.find(v) != to_local_src_node.end()) {
            rows.push_back(to_local_src_node.at(v));
            cols.push_back(i);
            edges.push_back(col_start + j);
          }
        }
      } else {
        for (int64_t j = col_start; j < col_end; j++) {
          const auto &v = row_data[j];
          if (to_local_src_node.find(v) != to_local_src_node.end()) {
            rows.push_back(to_local_src_node.at(v));
            cols.push_back(i);
            edges.push_back(j);
          }
rusty1s's avatar
rusty1s committed
222
223
224
        }
      }
    }
rusty1s's avatar
update  
rusty1s committed
225
    if (rows.size() > 0) {
rusty1s's avatar
rusty1s committed
226
227
228
      out_row_dict.insert(rel_type, from_vector<int64_t>(rows));
      out_col_dict.insert(rel_type, from_vector<int64_t>(cols));
      out_edge_dict.insert(rel_type, from_vector<int64_t>(edges));
rusty1s's avatar
update  
rusty1s committed
229
    }
rusty1s's avatar
rusty1s committed
230
231
  }

rusty1s's avatar
rusty1s committed
232
233
234
235
236
237
  // Generate tensor-valued output node dictionary (line 20):
  for (const auto &kv : nodes_dict) {
    const auto &node_type = kv.first;
    const auto &nodes = kv.second;
    if (!nodes.empty())
      out_node_dict.insert(node_type, from_vector<int64_t>(nodes));
rusty1s's avatar
rusty1s committed
238
239
  }

rusty1s's avatar
rusty1s committed
240
  return make_tuple(out_node_dict, out_row_dict, out_col_dict, out_edge_dict);
rusty1s's avatar
rusty1s committed
241
}