neighbor_sample_cpu.cpp 17.7 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#include "neighbor_sample_cpu.h"

#include "utils.h"

rusty1s's avatar
rusty1s committed
5
6
7
8
#ifdef _WIN32
#include <process.h>
#endif

rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
using namespace std;

namespace {

// Multi-hop neighbor sampling on a homogeneous graph stored in compressed
// column form: the neighbors of node `w` are row[colptr[w] .. colptr[w + 1]).
//
// Hop `ell` expands every node discovered in the previous hop (the input
// nodes for ell == 0) by `num_neighbors[ell]` neighbors:
//   * a negative count, or (without replacement) a count >= the degree,
//     takes every neighbor;
//   * `replace == true` draws `count` neighbors uniformly with replacement;
//   * otherwise `count` distinct neighbors are drawn (Robert Floyd's algorithm).
//
// Returns (samples, rows, cols, edges):
//   samples - global ids of all sampled nodes; position == local id,
//   rows    - local id of the sampled neighbor of each edge,
//   cols    - local id of the node that was expanded,
//   edges   - position of the edge inside `row`.
// With `directed == false` the edges are instead the full induced subgraph
// among all sampled nodes.
template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
       const torch::Tensor &input_node, const vector<int64_t> &num_neighbors) {

  // Global ids of the sampled nodes (index == local id) and the inverse map.
  vector<int64_t> samples;
  unordered_map<int64_t, int64_t> to_local_node;

  const auto *colptr_data = colptr.data_ptr<int64_t>();
  const auto *row_data = row.data_ptr<int64_t>();
  const auto *input_node_data = input_node.data_ptr<int64_t>();

  for (int64_t i = 0; i < input_node.numel(); i++) {
    const int64_t v = input_node_data[i];
    samples.push_back(v);
    to_local_node.insert({v, i});
  }

  vector<int64_t> rows, cols, edges;

  // Records the neighbor stored at position `offset` of `row` for the local
  // node `i`, registering it as a new sample the first time it is seen.
  auto add_neighbor = [&](int64_t i, int64_t offset) {
    const int64_t v = row_data[offset];
    const auto res =
        to_local_node.insert({v, static_cast<int64_t>(samples.size())});
    if (res.second)
      samples.push_back(v);
    if (directed) {
      cols.push_back(i);
      rows.push_back(res.first->second);
      edges.push_back(offset);
    }
  };

  int64_t begin = 0, end = samples.size();
  for (int64_t ell = 0; ell < (int64_t)num_neighbors.size(); ell++) {
    const int64_t num_samples = num_neighbors[ell];
    for (int64_t i = begin; i < end; i++) {
      // Copied by value: `samples` grows (and may reallocate) in this loop.
      const int64_t w = samples[i];
      const int64_t col_start = colptr_data[w];
      const int64_t col_end = colptr_data[w + 1];
      const int64_t col_count = col_end - col_start;

      if (col_count == 0)
        continue;

      if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
        // Take every neighbor.
        for (int64_t offset = col_start; offset < col_end; offset++)
          add_neighbor(i, offset);
      } else if (replace) {
        // uniform_randint(high) yields a value in [0, high).
        for (int64_t j = 0; j < num_samples; j++)
          add_neighbor(i, col_start + uniform_randint(col_count));
      } else {
        // Robert Floyd's algorithm: `num_samples` distinct positions in
        // [0, col_count). Each step must draw from [0, j] *inclusive*, hence
        // `j + 1` given the exclusive upper bound of uniform_randint.
        unordered_set<int64_t> rnd_indices;
        for (int64_t j = col_count - num_samples; j < col_count; j++) {
          int64_t rnd = uniform_randint(j + 1);
          if (!rnd_indices.insert(rnd).second) {
            rnd = j;
            rnd_indices.insert(j);
          }
          add_neighbor(i, col_start + rnd);
        }
      }
    }
    begin = end, end = samples.size();
  }

  if (!directed) {
    // Induced subgraph: keep every edge whose endpoints were both sampled.
    for (int64_t i = 0; i < (int64_t)samples.size(); i++) {
      const int64_t w = samples[i];
      for (int64_t offset = colptr_data[w]; offset < colptr_data[w + 1];
           offset++) {
        const auto iter = to_local_node.find(row_data[offset]);
        if (iter != to_local_node.end()) {
          rows.push_back(iter->second);
          cols.push_back(i);
          edges.push_back(offset);
        }
      }
    }
  }

  return make_tuple(from_vector<int64_t>(samples), from_vector<int64_t>(rows),
                    from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}

Rex Ying's avatar
Rex Ying committed
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Returns whether adding `sampled_node` (a node of type `src_node_type`) to the
// computation tree of a root with timestamp `dst_time` respects the temporal
// constraint: the sampled node's timestamp must not be greater than the root's.
// Node types that have no entry in `node_time_dict` carry no timestamps and
// are always admissible (plain, non-temporal sampling for that type).
bool satisfy_time_constraint(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
                             const std::string &src_node_type,
                             const int64_t &dst_time,
                             const int64_t &sampled_node) {
  // Explicit lookup instead of at() + catch: at() does not throw an `int`, so
  // a `catch (int)` fallback can never trigger, and exceptions on this
  // per-neighbor path would be expensive anyway.
  const auto it = node_time_dict.find(src_node_type);
  if (it == node_time_dict.end())
    return true;
  const auto *src_time = it->value().data_ptr<int64_t>();
  return src_time[sampled_node] <= dst_time;
}


template <bool replace, bool directed, bool temporal>
rusty1s's avatar
bugfix  
rusty1s committed
134
135
136
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const vector<node_t> &node_types,
Rex Ying's avatar
Rex Ying committed
137
138
139
140
141
142
143
144
                       const vector<edge_t> &edge_types,
                       const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
                       const c10::Dict<rel_t, torch::Tensor> &row_dict,
                       const c10::Dict<node_t, torch::Tensor> &input_node_dict,
                       const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
                       const int64_t num_hops,
                       const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
  //bool temporal = (!node_time_dict.empty());
rusty1s's avatar
rusty1s committed
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159

  // Create a mapping to convert single string relations to edge type triplets:
  unordered_map<rel_t, edge_t> to_edge_type;
  for (const auto &k : edge_types)
    to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;

  // Initialize some data structures for the sampling process:
  unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
  for (const auto &kv : colptr_dict) {
    const auto &rel_type = kv.key();
    rows_dict[rel_type];
    cols_dict[rel_type];
    edges_dict[rel_type];
  }

Rex Ying's avatar
Rex Ying committed
160
161
162
163
164
165
166
167
168
169
170
  unordered_map<node_t, vector<int64_t>> samples_dict;
  unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
  // The timestamp of the center node whose neighborhood that the sampled node
  // belongs to. It maps node_type to empty vector in non-temporal sampling.
  unordered_map<node_t, vector<int64_t>> root_time_dict;
  for (const auto &node_type : node_types) {
    samples_dict[node_type];
    to_local_node_dict[node_type];
    root_time_dict[node_type];
  }

rusty1s's avatar
rusty1s committed
171
172
173
  // Add the input nodes to the output nodes:
  for (const auto &kv : input_node_dict) {
    const auto &node_type = kv.key();
Michał Marcinkiewicz's avatar
Michał Marcinkiewicz committed
174
    const torch::Tensor &input_node = kv.value();
rusty1s's avatar
rusty1s committed
175
    const auto *input_node_data = input_node.data_ptr<int64_t>();
Rex Ying's avatar
Rex Ying committed
176
177
178
179
180
181
182
    // dummy value. will be reset to root time if is_temporal==true
    auto *node_time_data = input_node.data_ptr<int64_t>();
    // root_time[i] stores the timestamp of the computation tree root
    // of the node samples[i]
    if (temporal) {
      node_time_data = node_time_dict.at(node_type).data_ptr<int64_t>();
    }
rusty1s's avatar
rusty1s committed
183
184
185

    auto &samples = samples_dict.at(node_type);
    auto &to_local_node = to_local_node_dict.at(node_type);
Rex Ying's avatar
Rex Ying committed
186
    auto &root_time = root_time_dict.at(node_type);
rusty1s's avatar
rusty1s committed
187
188
189
190
    for (int64_t i = 0; i < input_node.numel(); i++) {
      const auto &v = input_node_data[i];
      samples.push_back(v);
      to_local_node.insert({v, i});
Rex Ying's avatar
Rex Ying committed
191
192
193
      if (temporal) {
        root_time.push_back(node_time_data[v]);
      }
rusty1s's avatar
rusty1s committed
194
195
196
197
198
199
200
201
202
203
204
205
206
    }
  }

  unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
  for (const auto &kv : samples_dict)
    slice_dict[kv.first] = {0, kv.second.size()};

  for (int64_t ell = 0; ell < num_hops; ell++) {
    for (const auto &kv : num_neighbors_dict) {
      const auto &rel_type = kv.key();
      const auto &edge_type = to_edge_type[rel_type];
      const auto &src_node_type = get<0>(edge_type);
      const auto &dst_node_type = get<2>(edge_type);
rusty1s's avatar
bugfix  
rusty1s committed
207
      const auto num_samples = kv.value()[ell];
rusty1s's avatar
rusty1s committed
208
209
210
211
      const auto &dst_samples = samples_dict.at(dst_node_type);
      auto &src_samples = samples_dict.at(src_node_type);
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);

212
213
214
215
      const auto *colptr_data =
          ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
      const auto *row_data =
          ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
216
217
218
219
220
221
222

      auto &rows = rows_dict.at(rel_type);
      auto &cols = cols_dict.at(rel_type);
      auto &edges = edges_dict.at(rel_type);

      const auto &begin = slice_dict.at(dst_node_type).first;
      const auto &end = slice_dict.at(dst_node_type).second;
Rex Ying's avatar
Rex Ying committed
223
224
225
226
227
228
229
230
      if (begin == end){
        continue;
      }
      // for temporal sampling, sampled src node cannot have timestamp greater
      // than its corresponding dst_root_time
      const auto &dst_root_time = root_time_dict.at(dst_node_type);
      auto &src_root_time = root_time_dict.at(src_node_type);

rusty1s's avatar
rusty1s committed
231
232
      for (int64_t i = begin; i < end; i++) {
        const auto &w = dst_samples[i];
Rex Ying's avatar
Rex Ying committed
233
        const auto &dst_time = dst_root_time[i];
rusty1s's avatar
rusty1s committed
234
235
236
237
238
239
240
        const auto &col_start = colptr_data[w];
        const auto &col_end = colptr_data[w + 1];
        const auto col_count = col_end - col_start;

        if (col_count == 0)
          continue;

rusty1s's avatar
bugfix  
rusty1s committed
241
        if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
Rex Ying's avatar
Rex Ying committed
242
          // select all neighbors
rusty1s's avatar
bugfix  
rusty1s committed
243
          for (int64_t offset = col_start; offset < col_end; offset++) {
rusty1s's avatar
rusty1s committed
244
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
245
246
247
248
249
250
251
            bool time_constraint = true;
            if (temporal) {
              time_constraint = satisfy_time_constraint(
                  node_time_dict, src_node_type, dst_time, v);
            }
            if (!time_constraint)
              continue;
rusty1s's avatar
rusty1s committed
252
            const auto res = to_local_src_node.insert({v, src_samples.size()});
Rex Ying's avatar
Rex Ying committed
253
            if (res.second) {
rusty1s's avatar
rusty1s committed
254
              src_samples.push_back(v);
Rex Ying's avatar
Rex Ying committed
255
256
257
              if (temporal)
                src_root_time.push_back(dst_time);
            }
rusty1s's avatar
rusty1s committed
258
259
260
261
262
263
            if (directed) {
              cols.push_back(i);
              rows.push_back(res.first->second);
              edges.push_back(offset);
            }
          }
rusty1s's avatar
bugfix  
rusty1s committed
264
        } else if (replace) {
Rex Ying's avatar
Rex Ying committed
265
266
267
          // sample with replacement
          int64_t num_neighbors = 0;
          while (num_neighbors < num_samples) {
268
            const int64_t offset = col_start + uniform_randint(col_count);
rusty1s's avatar
rusty1s committed
269
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
270
271
272
273
274
275
276
            bool time_constraint = true;
            if (temporal) {
              time_constraint = satisfy_time_constraint(
                  node_time_dict, src_node_type, dst_time, v);
            }
            if (!time_constraint)
              continue;
rusty1s's avatar
rusty1s committed
277
            const auto res = to_local_src_node.insert({v, src_samples.size()});
Rex Ying's avatar
Rex Ying committed
278
            if (res.second) {
rusty1s's avatar
rusty1s committed
279
              src_samples.push_back(v);
Rex Ying's avatar
Rex Ying committed
280
281
282
              if (temporal)
                src_root_time.push_back(dst_time);
            }
rusty1s's avatar
rusty1s committed
283
284
285
286
287
            if (directed) {
              cols.push_back(i);
              rows.push_back(res.first->second);
              edges.push_back(offset);
            }
Rex Ying's avatar
Rex Ying committed
288
            num_neighbors += 1;
rusty1s's avatar
rusty1s committed
289
290
          }
        } else {
Rex Ying's avatar
Rex Ying committed
291
          // sample without replacement
rusty1s's avatar
rusty1s committed
292
293
          unordered_set<int64_t> rnd_indices;
          for (int64_t j = col_count - num_samples; j < col_count; j++) {
294
            int64_t rnd = uniform_randint(j);
rusty1s's avatar
rusty1s committed
295
296
297
298
299
300
            if (!rnd_indices.insert(rnd).second) {
              rnd = j;
              rnd_indices.insert(j);
            }
            const int64_t offset = col_start + rnd;
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
301
302
303
304
305
306
307
            bool time_constraint = true;
            if (temporal) {
              time_constraint = satisfy_time_constraint(
                  node_time_dict, src_node_type, dst_time, v);
            }
            if (!time_constraint)
              continue;
rusty1s's avatar
rusty1s committed
308
            const auto res = to_local_src_node.insert({v, src_samples.size()});
Rex Ying's avatar
Rex Ying committed
309
            if (res.second) {
rusty1s's avatar
rusty1s committed
310
              src_samples.push_back(v);
Rex Ying's avatar
Rex Ying committed
311
312
313
              if (temporal)
                src_root_time.push_back(dst_time);
            }
rusty1s's avatar
rusty1s committed
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
            if (directed) {
              cols.push_back(i);
              rows.push_back(res.first->second);
              edges.push_back(offset);
            }
          }
        }
      }
    }

    for (const auto &kv : samples_dict) {
      slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
    }
  }

  if (!directed) { // Construct the subgraph among the sampled nodes:
    unordered_map<int64_t, int64_t>::iterator iter;
    for (const auto &kv : colptr_dict) {
      const auto &rel_type = kv.key();
      const auto &edge_type = to_edge_type[rel_type];
      const auto &src_node_type = get<0>(edge_type);
      const auto &dst_node_type = get<2>(edge_type);
      const auto &dst_samples = samples_dict.at(dst_node_type);
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);

Michał Marcinkiewicz's avatar
Michał Marcinkiewicz committed
339
      const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr<int64_t>();
340
341
      const auto *row_data =
          ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369

      auto &rows = rows_dict.at(rel_type);
      auto &cols = cols_dict.at(rel_type);
      auto &edges = edges_dict.at(rel_type);

      for (int64_t i = 0; i < (int64_t)dst_samples.size(); i++) {
        const auto &w = dst_samples[i];
        const auto &col_start = colptr_data[w];
        const auto &col_end = colptr_data[w + 1];
        for (int64_t offset = col_start; offset < col_end; offset++) {
          const auto &v = row_data[offset];
          iter = to_local_src_node.find(v);
          if (iter != to_local_src_node.end()) {
            rows.push_back(iter->second);
            cols.push_back(i);
            edges.push_back(offset);
          }
        }
      }
    }
  }

  return make_tuple(from_vector<node_t, int64_t>(samples_dict),
                    from_vector<rel_t, int64_t>(rows_dict),
                    from_vector<rel_t, int64_t>(cols_dict),
                    from_vector<rel_t, int64_t>(edges_dict));
}

Rex Ying's avatar
Rex Ying committed
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
// Plain (non-temporal) heterogeneous neighbor sampling: instantiates
// hetero_sample with the temporal switch off and hands it an empty
// node-time dictionary, which is never consulted in that mode.
template <bool replace, bool directed>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample_random(const vector<node_t> &node_types,
                     const vector<edge_t> &edge_types,
                     const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
                     const c10::Dict<rel_t, torch::Tensor> &row_dict,
                     const c10::Dict<node_t, torch::Tensor> &input_node_dict,
                     const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
                     const int64_t num_hops) {
  constexpr bool kTemporal = false;
  c10::Dict<node_t, torch::Tensor> no_node_times;
  return hetero_sample<replace, directed, kTemporal>(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, num_hops, no_node_times);
}

rusty1s's avatar
rusty1s committed
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
} // namespace

// Homogeneous neighbor sampling entry point: maps the runtime `replace` /
// `directed` flags onto the matching compile-time instantiation of sample().
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
                    const torch::Tensor &input_node,
                    const vector<int64_t> num_neighbors, const bool replace,
                    const bool directed) {
  if (replace) {
    return directed
               ? sample<true, true>(colptr, row, input_node, num_neighbors)
               : sample<true, false>(colptr, row, input_node, num_neighbors);
  }
  return directed
             ? sample<false, true>(colptr, row, input_node, num_neighbors)
             : sample<false, false>(colptr, row, input_node, num_neighbors);
}

rusty1s's avatar
bugfix  
rusty1s committed
410
411
// Heterogeneous (non-temporal) neighbor sampling entry point: selects the
// hetero_sample_random instantiation that matches the runtime flags.
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const int64_t num_hops, const bool replace, const bool directed) {
  if (replace) {
    if (directed)
      return hetero_sample_random<true, true>(node_types, edge_types,
                                              colptr_dict, row_dict,
                                              input_node_dict,
                                              num_neighbors_dict, num_hops);
    return hetero_sample_random<true, false>(node_types, edge_types,
                                             colptr_dict, row_dict,
                                             input_node_dict,
                                             num_neighbors_dict, num_hops);
  }
  if (directed)
    return hetero_sample_random<false, true>(node_types, edge_types,
                                             colptr_dict, row_dict,
                                             input_node_dict,
                                             num_neighbors_dict, num_hops);
  return hetero_sample_random<false, false>(node_types, edge_types,
                                            colptr_dict, row_dict,
                                            input_node_dict,
                                            num_neighbors_dict, num_hops);
}

// Heterogeneous temporal neighbor sampling entry point: same flag dispatch as
// the non-temporal variant, but always instantiates hetero_sample with
// temporal = true and forwards the per-type node timestamps.
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_temporal_sample_cpu(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const c10::Dict<node_t, torch::Tensor> &node_time_dict,
    const int64_t num_hops, const bool replace, const bool directed) {
  if (replace) {
    if (directed)
      return hetero_sample<true, true, true>(
          node_types, edge_types, colptr_dict, row_dict, input_node_dict,
          num_neighbors_dict, num_hops, node_time_dict);
    return hetero_sample<true, false, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, num_hops, node_time_dict);
  }
  if (directed)
    return hetero_sample<false, true, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, num_hops, node_time_dict);
  return hetero_sample<false, false, true>(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, num_hops, node_time_dict);
}