neighbor_sample_cpu.cpp 16.2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#include "neighbor_sample_cpu.h"

#include "utils.h"

rusty1s's avatar
rusty1s committed
5
6
7
8
#ifdef _WIN32
#include <process.h>
#endif

rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
using namespace std;

namespace {

/// Samples a multi-hop neighborhood around `input_node` on a homogeneous
/// graph whose adjacency is given by column pointers `colptr` and row indices
/// `row` (the neighbors of node `w` are `row[colptr[w] .. colptr[w + 1])`).
///
/// Template parameters:
///   replace  - draw neighbors with replacement.
///   directed - if true, only the sampled edges are returned; otherwise all
///              edges of the sampled nodes whose other endpoint was sampled
///              too are returned.
///
/// `num_neighbors[ell]` is the number of neighbors drawn per node in hop
/// `ell`; a negative value selects every neighbor.
///
/// Returns `(samples, rows, cols, edges)` converted via `from_vector`:
/// `samples` holds the global ids of all visited nodes (input nodes first),
/// `rows`/`cols` are indices into `samples`, and `edges` holds the matching
/// offsets into `row`.
template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
       const torch::Tensor &input_node, const vector<int64_t> &num_neighbors) {

  const auto *colptr_data = colptr.data_ptr<int64_t>();
  const auto *row_data = row.data_ptr<int64_t>();
  const auto *input_node_data = input_node.data_ptr<int64_t>();

  // `samples` lists visited global node ids; `to_local_node` maps a global id
  // to its position in `samples`.
  vector<int64_t> samples;
  unordered_map<int64_t, int64_t> to_local_node;

  const int64_t num_inputs = input_node.numel();
  samples.reserve(num_inputs);
  for (int64_t i = 0; i < num_inputs; i++) {
    const int64_t v = input_node_data[i];
    samples.push_back(v);
    to_local_node.insert({v, i});
  }

  vector<int64_t> rows, cols, edges;

  // Registers the neighbor stored at `offset` for the node at local index
  // `dst`, and records the edge when only sampled edges are requested.
  const auto add_neighbor = [&](const int64_t dst, const int64_t offset) {
    const int64_t v = row_data[offset];
    const auto res =
        to_local_node.insert({v, static_cast<int64_t>(samples.size())});
    if (res.second)
      samples.push_back(v);
    if (directed) {
      cols.push_back(dst);
      rows.push_back(res.first->second);
      edges.push_back(offset);
    }
  };

  // [begin, end) is the slice of `samples` discovered in the previous hop.
  int64_t begin = 0;
  int64_t end = static_cast<int64_t>(samples.size());
  for (const int64_t num_samples : num_neighbors) {
    for (int64_t i = begin; i < end; i++) {
      // Copy instead of binding a reference: `samples` grows below.
      const int64_t w = samples[i];
      const int64_t col_start = colptr_data[w];
      const int64_t col_end = colptr_data[w + 1];
      const int64_t col_count = col_end - col_start;

      if (col_count == 0)
        continue;

      if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
        // Select all neighbors:
        for (int64_t offset = col_start; offset < col_end; offset++)
          add_neighbor(i, offset);
      } else if (replace) {
        // Sample with replacement:
        for (int64_t j = 0; j < num_samples; j++)
          add_neighbor(i, col_start + uniform_randint(col_count));
      } else {
        // Sample without replacement (Floyd's algorithm). The draw has to
        // cover [0, j] inclusive. `uniform_randint(n)` is used above as an
        // index below `n`, so the bound passed here is `j + 1`; with `j` the
        // highest candidate could only ever be picked on a collision.
        unordered_set<int64_t> rnd_indices;
        for (int64_t j = col_count - num_samples; j < col_count; j++) {
          int64_t rnd = uniform_randint(j + 1);
          if (!rnd_indices.insert(rnd).second) {
            rnd = j;
            rnd_indices.insert(j);
          }
          add_neighbor(i, col_start + rnd);
        }
      }
    }
    begin = end;
    end = static_cast<int64_t>(samples.size());
  }

  if (!directed) { // Construct the subgraph among the sampled nodes:
    const int64_t num_sampled = static_cast<int64_t>(samples.size());
    for (int64_t i = 0; i < num_sampled; i++) {
      const int64_t w = samples[i];
      const int64_t col_start = colptr_data[w];
      const int64_t col_end = colptr_data[w + 1];
      for (int64_t offset = col_start; offset < col_end; offset++) {
        const auto it = to_local_node.find(row_data[offset]);
        if (it != to_local_node.end()) {
          rows.push_back(it->second);
          cols.push_back(i);
          edges.push_back(offset);
        }
      }
    }
  }

  return make_tuple(from_vector<int64_t>(samples), from_vector<int64_t>(rows),
                    from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}

117
118
119
120
/// Checks whether node `src_node` of type `src_node_type` may be sampled for
/// a root with timestamp `dst_time`.
///
/// A neighbor is admissible when its timestamp is not greater than the
/// timestamp of the root it is sampled for. If `node_time_dict` holds no
/// timestamps for `src_node_type`, no constraint applies and `true` is
/// returned.
inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
                         const node_t &src_node_type, const int64_t &dst_time,
                         const int64_t &src_node) {
  // Look the type up explicitly: the previous `catch (int)` handler could
  // only catch a thrown `int`, so the fallback was unreachable, and a lookup
  // also keeps exceptions out of the per-neighbor path.
  const auto it = node_time_dict.find(src_node_type);
  if (it == node_time_dict.end())
    return true; // No time given, fall back to normal sampling.

  const auto *src_time = it->value().data_ptr<int64_t>();
  return src_time[src_node] <= dst_time;
}

template <bool replace, bool directed, bool temporal>
rusty1s's avatar
bugfix  
rusty1s committed
129
130
131
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const vector<node_t> &node_types,
Matthias Fey's avatar
Matthias Fey committed
132
133
134
135
136
              const vector<edge_t> &edge_types,
              const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
              const c10::Dict<rel_t, torch::Tensor> &row_dict,
              const c10::Dict<node_t, torch::Tensor> &input_node_dict,
              const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
137
138
139
              const c10::Dict<node_t, torch::Tensor> &node_time_dict,
              const int64_t num_hops) {

rusty1s's avatar
rusty1s committed
140
141
142
143
144
145
146
147
148
149
150
151
152
153
  // Create a mapping to convert single string relations to edge type triplets:
  unordered_map<rel_t, edge_t> to_edge_type;
  for (const auto &k : edge_types)
    to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;

  // Initialize some data structures for the sampling process:
  unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
  for (const auto &kv : colptr_dict) {
    const auto &rel_type = kv.key();
    rows_dict[rel_type];
    cols_dict[rel_type];
    edges_dict[rel_type];
  }

Rex Ying's avatar
Rex Ying committed
154
155
156
157
158
159
160
161
162
  unordered_map<node_t, vector<int64_t>> samples_dict;
  unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
  unordered_map<node_t, vector<int64_t>> root_time_dict;
  for (const auto &node_type : node_types) {
    samples_dict[node_type];
    to_local_node_dict[node_type];
    root_time_dict[node_type];
  }

rusty1s's avatar
rusty1s committed
163
164
165
  // Add the input nodes to the output nodes:
  for (const auto &kv : input_node_dict) {
    const auto &node_type = kv.key();
Michał Marcinkiewicz's avatar
Michał Marcinkiewicz committed
166
    const torch::Tensor &input_node = kv.value();
rusty1s's avatar
rusty1s committed
167
    const auto *input_node_data = input_node.data_ptr<int64_t>();
Matthias Fey's avatar
Matthias Fey committed
168
    int64_t *node_time_data;
Rex Ying's avatar
Rex Ying committed
169
    if (temporal) {
Matthias Fey's avatar
Matthias Fey committed
170
171
      torch::Tensor node_time = node_time_dict.at(node_type);
      node_time_data = node_time.data_ptr<int64_t>();
Rex Ying's avatar
Rex Ying committed
172
    }
rusty1s's avatar
rusty1s committed
173
174
175

    auto &samples = samples_dict.at(node_type);
    auto &to_local_node = to_local_node_dict.at(node_type);
Rex Ying's avatar
Rex Ying committed
176
    auto &root_time = root_time_dict.at(node_type);
rusty1s's avatar
rusty1s committed
177
178
179
180
    for (int64_t i = 0; i < input_node.numel(); i++) {
      const auto &v = input_node_data[i];
      samples.push_back(v);
      to_local_node.insert({v, i});
181
      if (temporal)
Rex Ying's avatar
Rex Ying committed
182
        root_time.push_back(node_time_data[v]);
rusty1s's avatar
rusty1s committed
183
184
185
186
187
188
189
190
191
192
193
194
195
    }
  }

  unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
  for (const auto &kv : samples_dict)
    slice_dict[kv.first] = {0, kv.second.size()};

  for (int64_t ell = 0; ell < num_hops; ell++) {
    for (const auto &kv : num_neighbors_dict) {
      const auto &rel_type = kv.key();
      const auto &edge_type = to_edge_type[rel_type];
      const auto &src_node_type = get<0>(edge_type);
      const auto &dst_node_type = get<2>(edge_type);
rusty1s's avatar
bugfix  
rusty1s committed
196
      const auto num_samples = kv.value()[ell];
rusty1s's avatar
rusty1s committed
197
198
199
200
      const auto &dst_samples = samples_dict.at(dst_node_type);
      auto &src_samples = samples_dict.at(src_node_type);
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);

201
202
203
204
      const auto *colptr_data =
          ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
      const auto *row_data =
          ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
205
206
207
208
209
210
211

      auto &rows = rows_dict.at(rel_type);
      auto &cols = cols_dict.at(rel_type);
      auto &edges = edges_dict.at(rel_type);

      const auto &begin = slice_dict.at(dst_node_type).first;
      const auto &end = slice_dict.at(dst_node_type).second;
212
213

      if (begin == end)
Rex Ying's avatar
Rex Ying committed
214
        continue;
215
216
217

      // For temporal sampling, sampled nodes cannot have a timestamp greater
      // than the timestamp of the root nodes.
Rex Ying's avatar
Rex Ying committed
218
219
220
      const auto &dst_root_time = root_time_dict.at(dst_node_type);
      auto &src_root_time = root_time_dict.at(src_node_type);

rusty1s's avatar
rusty1s committed
221
222
      for (int64_t i = begin; i < end; i++) {
        const auto &w = dst_samples[i];
Rex Ying's avatar
Rex Ying committed
223
        const auto &dst_time = dst_root_time[i];
rusty1s's avatar
rusty1s committed
224
225
226
227
228
229
230
        const auto &col_start = colptr_data[w];
        const auto &col_end = colptr_data[w + 1];
        const auto col_count = col_end - col_start;

        if (col_count == 0)
          continue;

rusty1s's avatar
bugfix  
rusty1s committed
231
        if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
232
          // Select all neighbors:
rusty1s's avatar
bugfix  
rusty1s committed
233
          for (int64_t offset = col_start; offset < col_end; offset++) {
rusty1s's avatar
rusty1s committed
234
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
235
            if (temporal) {
236
237
              if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
                continue;
Rex Ying's avatar
Rex Ying committed
238
            }
rusty1s's avatar
rusty1s committed
239
            const auto res = to_local_src_node.insert({v, src_samples.size()});
Rex Ying's avatar
Rex Ying committed
240
            if (res.second) {
rusty1s's avatar
rusty1s committed
241
              src_samples.push_back(v);
Rex Ying's avatar
Rex Ying committed
242
243
244
              if (temporal)
                src_root_time.push_back(dst_time);
            }
rusty1s's avatar
rusty1s committed
245
246
247
248
249
250
            if (directed) {
              cols.push_back(i);
              rows.push_back(res.first->second);
              edges.push_back(offset);
            }
          }
rusty1s's avatar
bugfix  
rusty1s committed
251
        } else if (replace) {
252
          // Sample with replacement:
Rex Ying's avatar
Rex Ying committed
253
254
          int64_t num_neighbors = 0;
          while (num_neighbors < num_samples) {
255
            const int64_t offset = col_start + uniform_randint(col_count);
rusty1s's avatar
rusty1s committed
256
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
257
            if (temporal) {
258
259
260
              // TODO Infinity loop if no neighbor satisfies time constraint:
              if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
                continue;
Rex Ying's avatar
Rex Ying committed
261
            }
rusty1s's avatar
rusty1s committed
262
            const auto res = to_local_src_node.insert({v, src_samples.size()});
Rex Ying's avatar
Rex Ying committed
263
            if (res.second) {
rusty1s's avatar
rusty1s committed
264
              src_samples.push_back(v);
Rex Ying's avatar
Rex Ying committed
265
266
267
              if (temporal)
                src_root_time.push_back(dst_time);
            }
rusty1s's avatar
rusty1s committed
268
269
270
271
272
            if (directed) {
              cols.push_back(i);
              rows.push_back(res.first->second);
              edges.push_back(offset);
            }
Rex Ying's avatar
Rex Ying committed
273
            num_neighbors += 1;
rusty1s's avatar
rusty1s committed
274
275
          }
        } else {
276
          // Sample without replacement:
rusty1s's avatar
rusty1s committed
277
278
          unordered_set<int64_t> rnd_indices;
          for (int64_t j = col_count - num_samples; j < col_count; j++) {
279
            int64_t rnd = uniform_randint(j);
rusty1s's avatar
rusty1s committed
280
281
282
283
284
285
            if (!rnd_indices.insert(rnd).second) {
              rnd = j;
              rnd_indices.insert(j);
            }
            const int64_t offset = col_start + rnd;
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
286
            if (temporal) {
287
288
              if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
                continue;
Rex Ying's avatar
Rex Ying committed
289
            }
rusty1s's avatar
rusty1s committed
290
            const auto res = to_local_src_node.insert({v, src_samples.size()});
Rex Ying's avatar
Rex Ying committed
291
            if (res.second) {
rusty1s's avatar
rusty1s committed
292
              src_samples.push_back(v);
Rex Ying's avatar
Rex Ying committed
293
294
295
              if (temporal)
                src_root_time.push_back(dst_time);
            }
rusty1s's avatar
rusty1s committed
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
            if (directed) {
              cols.push_back(i);
              rows.push_back(res.first->second);
              edges.push_back(offset);
            }
          }
        }
      }
    }

    for (const auto &kv : samples_dict) {
      slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
    }
  }

  if (!directed) { // Construct the subgraph among the sampled nodes:
    unordered_map<int64_t, int64_t>::iterator iter;
    for (const auto &kv : colptr_dict) {
      const auto &rel_type = kv.key();
      const auto &edge_type = to_edge_type[rel_type];
      const auto &src_node_type = get<0>(edge_type);
      const auto &dst_node_type = get<2>(edge_type);
      const auto &dst_samples = samples_dict.at(dst_node_type);
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);

Michał Marcinkiewicz's avatar
Michał Marcinkiewicz committed
321
      const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr<int64_t>();
322
323
      const auto *row_data =
          ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370

      auto &rows = rows_dict.at(rel_type);
      auto &cols = cols_dict.at(rel_type);
      auto &edges = edges_dict.at(rel_type);

      for (int64_t i = 0; i < (int64_t)dst_samples.size(); i++) {
        const auto &w = dst_samples[i];
        const auto &col_start = colptr_data[w];
        const auto &col_end = colptr_data[w + 1];
        for (int64_t offset = col_start; offset < col_end; offset++) {
          const auto &v = row_data[offset];
          iter = to_local_src_node.find(v);
          if (iter != to_local_src_node.end()) {
            rows.push_back(iter->second);
            cols.push_back(i);
            edges.push_back(offset);
          }
        }
      }
    }
  }

  return make_tuple(from_vector<node_t, int64_t>(samples_dict),
                    from_vector<rel_t, int64_t>(rows_dict),
                    from_vector<rel_t, int64_t>(cols_dict),
                    from_vector<rel_t, int64_t>(edges_dict));
}

} // namespace

/// Entry point for homogeneous neighbor sampling on the CPU.
///
/// Maps the runtime flags `replace` and `directed` onto the matching
/// compile-time specialization of `sample` and forwards all other arguments
/// unchanged.
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
                    const torch::Tensor &input_node,
                    const vector<int64_t> num_neighbors, const bool replace,
                    const bool directed) {

  if (replace) {
    if (directed)
      return sample<true, true>(colptr, row, input_node, num_neighbors);
    return sample<true, false>(colptr, row, input_node, num_neighbors);
  }

  if (directed)
    return sample<false, true>(colptr, row, input_node, num_neighbors);
  return sample<false, false>(colptr, row, input_node, num_neighbors);
}

rusty1s's avatar
bugfix  
rusty1s committed
371
372
/// Entry point for heterogeneous neighbor sampling without time constraints.
///
/// Maps the runtime flags `replace` and `directed` onto the matching
/// non-temporal specialization of `hetero_sample`, handing it an empty
/// timestamp dictionary.
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const int64_t num_hops, const bool replace, const bool directed) {

  // Timestamps are not used here, so an empty dictionary is passed along:
  c10::Dict<node_t, torch::Tensor> node_time_dict;

  if (replace) {
    if (directed) {
      return hetero_sample<true, true, false>(
          node_types, edge_types, colptr_dict, row_dict, input_node_dict,
          num_neighbors_dict, node_time_dict, num_hops);
    }
    return hetero_sample<true, false, false>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, node_time_dict, num_hops);
  }

  if (directed) {
    return hetero_sample<false, true, false>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, node_time_dict, num_hops);
  }
  return hetero_sample<false, false, false>(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, node_time_dict, num_hops);
}

/// Entry point for heterogeneous neighbor sampling with time constraints.
///
/// Maps the runtime flags `replace` and `directed` onto the matching temporal
/// specialization of `hetero_sample` and forwards `node_time_dict` together
/// with all other arguments unchanged.
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_temporal_neighbor_sample_cpu(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const c10::Dict<node_t, torch::Tensor> &node_time_dict,
    const int64_t num_hops, const bool replace, const bool directed) {

  if (replace) {
    if (directed) {
      return hetero_sample<true, true, true>(
          node_types, edge_types, colptr_dict, row_dict, input_node_dict,
          num_neighbors_dict, node_time_dict, num_hops);
    }
    return hetero_sample<true, false, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, node_time_dict, num_hops);
  }

  if (directed) {
    return hetero_sample<false, true, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, node_time_dict, num_hops);
  }
  return hetero_sample<false, false, true>(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, node_time_dict, num_hops);
}