neighbor_sample_cpu.cpp 19.6 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#include "neighbor_sample_cpu.h"

#include "utils.h"

rusty1s's avatar
rusty1s committed
5
6
7
8
#ifdef _WIN32
#include <process.h>
#endif

rusty1s's avatar
rusty1s committed
9
10
11
12
using namespace std;

namespace {

13
14
// Maps a (node index, seed index) pair to that sample's local index within the
// temporal samples of one node type. The (sic) spelling of the alias is kept
// because other code in this file refers to it by this name.
using temporarl_edge_dict = phmap::flat_hash_map<pair<int64_t, int64_t>, int64_t>;

rusty1s's avatar
rusty1s committed
15
16
17
18
19
20
21
// Samples a multi-hop neighborhood around `input_node` in a graph stored in
// compressed (colptr/row) form: the neighbors of node `w` are
// `row[colptr[w]], ..., row[colptr[w + 1] - 1]`.
//
// `num_neighbors[ell]` is the number of neighbors drawn per node in hop `ell`
// (a negative value selects all neighbors). With `replace`, neighbors are drawn
// with replacement; with `directed`, only the sampled edges are returned,
// otherwise the full subgraph induced by all sampled nodes is returned.
//
// Returns (samples, rows, cols, edges): the global ids of all sampled nodes
// (input nodes first), the local indices of each edge's neighbor (`rows`) and
// expanded node (`cols`), and each edge's offset into `row`.
template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
       const torch::Tensor &input_node, const vector<int64_t> &num_neighbors) {
  // Initialize some data structures for the sampling process:
  vector<int64_t> samples;                               // local -> global id
  phmap::flat_hash_map<int64_t, int64_t> to_local_node; // global -> local id

  const auto *colptr_data = colptr.data_ptr<int64_t>();
  const auto *row_data = row.data_ptr<int64_t>();
  const auto *input_node_data = input_node.data_ptr<int64_t>();

  for (int64_t i = 0; i < input_node.numel(); i++) {
    const int64_t v = input_node_data[i];
    samples.push_back(v);
    to_local_node.insert({v, i});
  }

  vector<int64_t> rows, cols, edges;

  // Records the neighbor stored at `row[offset]` for the `i`-th sampled node,
  // assigning it a fresh local index the first time it is seen. In directed
  // mode the edge is emitted immediately; otherwise edges are built at the end.
  const auto add_neighbor = [&](int64_t i, int64_t offset) {
    const int64_t v = row_data[offset];
    const auto res =
        to_local_node.insert({v, static_cast<int64_t>(samples.size())});
    if (res.second)
      samples.push_back(v);
    if (directed) {
      cols.push_back(i);
      rows.push_back(res.first->second);
      edges.push_back(offset);
    }
  };

  // Hop `ell` expands exactly the nodes discovered in the previous hop, i.e.
  // the local index range [begin, end):
  int64_t begin = 0;
  int64_t end = static_cast<int64_t>(samples.size());
  for (int64_t ell = 0; ell < static_cast<int64_t>(num_neighbors.size());
       ell++) {
    const int64_t num_samples = num_neighbors[ell];
    for (int64_t i = begin; i < end; i++) {
      // Copy by value: `samples` grows (and may reallocate) in this loop.
      const int64_t w = samples[i];
      const int64_t col_start = colptr_data[w];
      const int64_t col_end = colptr_data[w + 1];
      const int64_t col_count = col_end - col_start;

      if (col_count == 0)
        continue;

      if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
        // Select all neighbors:
        for (int64_t offset = col_start; offset < col_end; offset++)
          add_neighbor(i, offset);
      } else if (replace) {
        // Sample with replacement:
        for (int64_t j = 0; j < num_samples; j++)
          add_neighbor(i, col_start + uniform_randint(col_count));
      } else {
        // Sample without replacement (Robert Floyd's algorithm): for each j in
        // [col_count - num_samples, col_count), draw `rnd` uniformly from
        // [0, j] and take `j` instead if `rnd` was already chosen.
        // `uniform_randint(n)` draws from [0, n) (the replacement branch above
        // relies on this to stay inside [col_start, col_end)), hence `j + 1`.
        unordered_set<int64_t> rnd_indices;
        for (int64_t j = col_count - num_samples; j < col_count; j++) {
          int64_t rnd = uniform_randint(j + 1);
          if (!rnd_indices.insert(rnd).second) {
            rnd = j;
            rnd_indices.insert(j);
          }
          add_neighbor(i, col_start + rnd);
        }
      }
    }
    begin = end;
    end = static_cast<int64_t>(samples.size());
  }

  if (!directed) {
    // Construct the subgraph induced by all sampled nodes:
    for (int64_t i = 0; i < static_cast<int64_t>(samples.size()); i++) {
      const int64_t w = samples[i];
      for (int64_t offset = colptr_data[w]; offset < colptr_data[w + 1];
           offset++) {
        const auto iter = to_local_node.find(row_data[offset]);
        if (iter != to_local_node.end()) {
          rows.push_back(iter->second);
          cols.push_back(i);
          edges.push_back(offset);
        }
      }
    }
  }

  return make_tuple(from_vector<int64_t>(samples), from_vector<int64_t>(rows),
                    from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}

119
// Returns whether the edge `src_node -> dst` obeys the temporal constraint,
// i.e. whether the timestamp of `src_node` (a node of type `src_node_type`) is
// not greater than `dst_time`. Node types without an entry in `node_time_dict`
// are unconstrained, so sampling falls back to the non-temporal behavior.
inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
                         const node_t &src_node_type, int64_t dst_time,
                         int64_t src_node) {
  // Look the type up instead of catching `std::out_of_range` from `at()`: this
  // is called once per candidate neighbor, and throwing an exception is far
  // too expensive for ordinary control flow.
  const auto it = node_time_dict.find(src_node_type);
  if (it == node_time_dict.end())
    return true; // If no time is given, fall back to normal sampling.
  const torch::Tensor src_node_time = it->value();
  return src_node_time.data_ptr<int64_t>()[src_node] <= dst_time;
}

template <bool replace, bool directed, bool temporal>
rusty1s's avatar
bugfix  
rusty1s committed
133
134
135
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const vector<node_t> &node_types,
Matthias Fey's avatar
Matthias Fey committed
136
137
138
139
140
              const vector<edge_t> &edge_types,
              const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
              const c10::Dict<rel_t, torch::Tensor> &row_dict,
              const c10::Dict<node_t, torch::Tensor> &input_node_dict,
              const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
141
142
143
              const c10::Dict<node_t, torch::Tensor> &node_time_dict,
              const int64_t num_hops) {

rusty1s's avatar
rusty1s committed
144
  // Create a mapping to convert single string relations to edge type triplets:
145
  phmap::flat_hash_map<rel_t, edge_t> to_edge_type;
rusty1s's avatar
rusty1s committed
146
147
148
149
  for (const auto &k : edge_types)
    to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;

  // Initialize some data structures for the sampling process:
150
  phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
151
  phmap::flat_hash_map<node_t, vector<pair<int64_t, int64_t>>> temp_samples_dict;
152
  phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
153
  phmap::flat_hash_map<node_t, temporarl_edge_dict> temp_to_local_node_dict;
154
  phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
Rex Ying's avatar
Rex Ying committed
155
156
  for (const auto &node_type : node_types) {
    samples_dict[node_type];
157
    temp_samples_dict[node_type];
Rex Ying's avatar
Rex Ying committed
158
    to_local_node_dict[node_type];
159
    temp_to_local_node_dict[node_type];
Rex Ying's avatar
Rex Ying committed
160
161
162
    root_time_dict[node_type];
  }

163
  phmap::flat_hash_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
164
165
166
167
168
169
170
  for (const auto &kv : colptr_dict) {
    const auto &rel_type = kv.key();
    rows_dict[rel_type];
    cols_dict[rel_type];
    edges_dict[rel_type];
  }

rusty1s's avatar
rusty1s committed
171
172
173
  // Add the input nodes to the output nodes:
  for (const auto &kv : input_node_dict) {
    const auto &node_type = kv.key();
Michał Marcinkiewicz's avatar
Michał Marcinkiewicz committed
174
    const torch::Tensor &input_node = kv.value();
rusty1s's avatar
rusty1s committed
175
    const auto *input_node_data = input_node.data_ptr<int64_t>();
176

Matthias Fey's avatar
Matthias Fey committed
177
    int64_t *node_time_data;
Rex Ying's avatar
Rex Ying committed
178
    if (temporal) {
179
      const torch::Tensor &node_time = node_time_dict.at(node_type);
Matthias Fey's avatar
Matthias Fey committed
180
      node_time_data = node_time.data_ptr<int64_t>();
Rex Ying's avatar
Rex Ying committed
181
    }
rusty1s's avatar
rusty1s committed
182
183

    auto &samples = samples_dict.at(node_type);
184
    auto &temp_samples = temp_samples_dict.at(node_type);
rusty1s's avatar
rusty1s committed
185
    auto &to_local_node = to_local_node_dict.at(node_type);
186
    auto &temp_to_local_node = temp_to_local_node_dict.at(node_type);
Rex Ying's avatar
Rex Ying committed
187
    auto &root_time = root_time_dict.at(node_type);
rusty1s's avatar
rusty1s committed
188
189
    for (int64_t i = 0; i < input_node.numel(); i++) {
      const auto &v = input_node_data[i];
190
191
192
193
194
195
196
      if (temporal) {
        temp_samples.push_back({v, i});
        temp_to_local_node.insert({{v, i}, i});
      } else {
        samples.push_back(v);
        to_local_node.insert({v, i});
      }
197
      if (temporal)
Rex Ying's avatar
Rex Ying committed
198
        root_time.push_back(node_time_data[v]);
rusty1s's avatar
rusty1s committed
199
200
201
    }
  }

202
  phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
203
204
205
206
207
208
209
210
  if (temporal) {
    for (const auto &kv : temp_samples_dict) {
      slice_dict[kv.first] = {0, kv.second.size()};
    }
  } else {
    for (const auto &kv : samples_dict)
      slice_dict[kv.first] = {0, kv.second.size()};
  }
rusty1s's avatar
rusty1s committed
211

212
213
214
215
216
217
  vector<rel_t> all_rel_types;
  for (const auto &kv : num_neighbors_dict) {
    all_rel_types.push_back(kv.key());
  }
  std::sort(all_rel_types.begin(), all_rel_types.end());

rusty1s's avatar
rusty1s committed
218
  for (int64_t ell = 0; ell < num_hops; ell++) {
219
    for (const auto &rel_type : all_rel_types) {
rusty1s's avatar
rusty1s committed
220
221
222
      const auto &edge_type = to_edge_type[rel_type];
      const auto &src_node_type = get<0>(edge_type);
      const auto &dst_node_type = get<2>(edge_type);
223
      const auto num_samples = num_neighbors_dict.at(rel_type)[ell];
rusty1s's avatar
rusty1s committed
224
      const auto &dst_samples = samples_dict.at(dst_node_type);
225
      const auto &temp_dst_samples = temp_samples_dict.at(dst_node_type);
rusty1s's avatar
rusty1s committed
226
      auto &src_samples = samples_dict.at(src_node_type);
227
      auto &temp_src_samples = temp_samples_dict.at(src_node_type);
rusty1s's avatar
rusty1s committed
228
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);
229
      auto &temp_to_local_src_node = temp_to_local_node_dict.at(src_node_type);
rusty1s's avatar
rusty1s committed
230

231
232
233
234
      const torch::Tensor &colptr = colptr_dict.at(rel_type);
      const auto *colptr_data = colptr.data_ptr<int64_t>();
      const torch::Tensor &row = row_dict.at(rel_type);
      const auto *row_data = row.data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
235
236
237
238
239

      auto &rows = rows_dict.at(rel_type);
      auto &cols = cols_dict.at(rel_type);
      auto &edges = edges_dict.at(rel_type);

240
      // For temporal sampling, sampled nodes cannot have a timestamp greater
241
      // than the timestamp of the root nodes:
Rex Ying's avatar
Rex Ying committed
242
243
244
      const auto &dst_root_time = root_time_dict.at(dst_node_type);
      auto &src_root_time = root_time_dict.at(src_node_type);

245
246
      const auto &begin = slice_dict.at(dst_node_type).first;
      const auto &end = slice_dict.at(dst_node_type).second;
rusty1s's avatar
rusty1s committed
247
      for (int64_t i = begin; i < end; i++) {
248
249
        const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
        const int64_t root_w = temporal ? temp_dst_samples[i].second : -1;
250
251
252
        int64_t dst_time = 0;
        if (temporal)
          dst_time = dst_root_time[i];
rusty1s's avatar
rusty1s committed
253
254
255
256
257
258
259
        const auto &col_start = colptr_data[w];
        const auto &col_end = colptr_data[w + 1];
        const auto col_count = col_end - col_start;

        if (col_count == 0)
          continue;

rusty1s's avatar
bugfix  
rusty1s committed
260
        if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
261
          // Select all neighbors:
rusty1s's avatar
bugfix  
rusty1s committed
262
          for (int64_t offset = col_start; offset < col_end; offset++) {
rusty1s's avatar
rusty1s committed
263
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
264
            if (temporal) {
265
266
              if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
                continue;
267
              // force disjoint of computation tree based on source batch idx.
268
269
270
              // note that the sampling always needs to have directed=True
              // for temporal case
              // to_local_src_node is not used for temporal / directed case
271
272
273
274
275
276
              const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
              if (res.second) {
                temp_src_samples.push_back({v, root_w});
                src_root_time.push_back(dst_time);
              }

rusty1s's avatar
rusty1s committed
277
              cols.push_back(i);
278
              rows.push_back(res.first->second);
rusty1s's avatar
rusty1s committed
279
              edges.push_back(offset);
280
281
282
283
284
285
286
287
288
            } else {
              const auto res = to_local_src_node.insert({v, src_samples.size()});
              if (res.second)
                src_samples.push_back(v);
              if (directed) {
                cols.push_back(i);
                rows.push_back(res.first->second);
                edges.push_back(offset);
              }
rusty1s's avatar
rusty1s committed
289
290
            }
          }
rusty1s's avatar
bugfix  
rusty1s committed
291
        } else if (replace) {
292
          // Sample with replacement:
Rex Ying's avatar
Rex Ying committed
293
294
          int64_t num_neighbors = 0;
          while (num_neighbors < num_samples) {
295
            const int64_t offset = col_start + uniform_randint(col_count);
rusty1s's avatar
rusty1s committed
296
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
297
            if (temporal) {
298
299
300
              // TODO Infinity loop if no neighbor satisfies time constraint:
              if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
                continue;
301
              // force disjoint of computation tree based on source batch idx.
302
303
              // note that the sampling always needs to have directed=True
              // for temporal case
304
305
306
307
308
309
              const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
              if (res.second) {
                temp_src_samples.push_back({v, root_w});
                src_root_time.push_back(dst_time);
              }

rusty1s's avatar
rusty1s committed
310
              cols.push_back(i);
311
              rows.push_back(res.first->second);
rusty1s's avatar
rusty1s committed
312
              edges.push_back(offset);
313
314
315
316
317
318
319
320
321
            } else {
              const auto res = to_local_src_node.insert({v, src_samples.size()});
              if (res.second)
                src_samples.push_back(v);
              if (directed) {
                cols.push_back(i);
                rows.push_back(res.first->second);
                edges.push_back(offset);
              }
rusty1s's avatar
rusty1s committed
322
            }
Rex Ying's avatar
Rex Ying committed
323
            num_neighbors += 1;
rusty1s's avatar
rusty1s committed
324
325
          }
        } else {
326
          // Sample without replacement:
rusty1s's avatar
rusty1s committed
327
328
          unordered_set<int64_t> rnd_indices;
          for (int64_t j = col_count - num_samples; j < col_count; j++) {
329
            int64_t rnd = uniform_randint(j);
rusty1s's avatar
rusty1s committed
330
331
332
333
334
335
            if (!rnd_indices.insert(rnd).second) {
              rnd = j;
              rnd_indices.insert(j);
            }
            const int64_t offset = col_start + rnd;
            const int64_t &v = row_data[offset];
Rex Ying's avatar
Rex Ying committed
336
            if (temporal) {
337
338
              if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
                continue;
339
              // force disjoint of computation tree based on source batch idx.
340
341
              // note that the sampling always needs to have directed=True
              // for temporal case
342
343
344
345
346
347
              const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
              if (res.second) {
                temp_src_samples.push_back({v, root_w});
                src_root_time.push_back(dst_time);
              }

rusty1s's avatar
rusty1s committed
348
              cols.push_back(i);
349
              rows.push_back(res.first->second);
rusty1s's avatar
rusty1s committed
350
              edges.push_back(offset);
351
352
353
354
355
356
357
358
359
            } else {
              const auto res = to_local_src_node.insert({v, src_samples.size()});
              if (res.second)
                src_samples.push_back(v);
              if (directed) {
                cols.push_back(i);
                rows.push_back(res.first->second);
                edges.push_back(offset);
              }
rusty1s's avatar
rusty1s committed
360
361
362
363
364
365
            }
          }
        }
      }
    }

366
367
368
369
370
371
372
    if (temporal) {
      for (const auto &kv : temp_samples_dict) {
        slice_dict[kv.first] = {0, kv.second.size()};
      }
    } else {
      for (const auto &kv : samples_dict)
        slice_dict[kv.first] = {0, kv.second.size()};
rusty1s's avatar
rusty1s committed
373
374
375
    }
  }

376
377
  // Temporal sample disable undirected
  assert(!(temporal && !directed));
rusty1s's avatar
rusty1s committed
378
  if (!directed) { // Construct the subgraph among the sampled nodes:
379
    phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
rusty1s's avatar
rusty1s committed
380
381
382
383
384
385
386
387
    for (const auto &kv : colptr_dict) {
      const auto &rel_type = kv.key();
      const auto &edge_type = to_edge_type[rel_type];
      const auto &src_node_type = get<0>(edge_type);
      const auto &dst_node_type = get<2>(edge_type);
      const auto &dst_samples = samples_dict.at(dst_node_type);
      auto &to_local_src_node = to_local_node_dict.at(src_node_type);

Michał Marcinkiewicz's avatar
Michał Marcinkiewicz committed
388
      const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr<int64_t>();
389
390
      const auto *row_data =
          ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
rusty1s's avatar
rusty1s committed
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412

      auto &rows = rows_dict.at(rel_type);
      auto &cols = cols_dict.at(rel_type);
      auto &edges = edges_dict.at(rel_type);

      for (int64_t i = 0; i < (int64_t)dst_samples.size(); i++) {
        const auto &w = dst_samples[i];
        const auto &col_start = colptr_data[w];
        const auto &col_end = colptr_data[w + 1];
        for (int64_t offset = col_start; offset < col_end; offset++) {
          const auto &v = row_data[offset];
          iter = to_local_src_node.find(v);
          if (iter != to_local_src_node.end()) {
            rows.push_back(iter->second);
            cols.push_back(i);
            edges.push_back(offset);
          }
        }
      }
    }
  }

413
414
415
416
417
418
419
420
421
422
423
424
  // Construct samples dictionary from temporal sample dictionary.
  if (temporal) {
    for (const auto &kv : temp_samples_dict) {
      const auto &node_type = kv.first;
      const auto &samples = kv.second;
      samples_dict[node_type].reserve(samples.size());
      for (const auto &v : samples) {
        samples_dict[node_type].push_back(v.first);
      }
    }
  }

rusty1s's avatar
rusty1s committed
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
  return make_tuple(from_vector<node_t, int64_t>(samples_dict),
                    from_vector<rel_t, int64_t>(rows_dict),
                    from_vector<rel_t, int64_t>(cols_dict),
                    from_vector<rel_t, int64_t>(edges_dict));
}

} // namespace

tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
                    const torch::Tensor &input_node,
                    const vector<int64_t> num_neighbors, const bool replace,
                    const bool directed) {
  // Map the runtime flags onto the matching compile-time instantiation of
  // `sample<replace, directed>`.
  if (replace) {
    return directed
               ? sample<true, true>(colptr, row, input_node, num_neighbors)
               : sample<true, false>(colptr, row, input_node, num_neighbors);
  }
  return directed
             ? sample<false, true>(colptr, row, input_node, num_neighbors)
             : sample<false, false>(colptr, row, input_node, num_neighbors);
}

rusty1s's avatar
bugfix  
rusty1s committed
450
451
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_neighbor_sample_cpu(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const int64_t num_hops, const bool replace, const bool directed) {
  // The non-temporal instantiations never consult timestamps, so an empty
  // dictionary is passed in their place.
  c10::Dict<node_t, torch::Tensor> node_time_dict;

  // Dispatch the runtime flags to `hetero_sample<replace, directed, false>`:
  if (replace) {
    if (directed)
      return hetero_sample<true, true, false>(
          node_types, edge_types, colptr_dict, row_dict, input_node_dict,
          num_neighbors_dict, node_time_dict, num_hops);
    return hetero_sample<true, false, false>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, node_time_dict, num_hops);
  }
  if (directed)
    return hetero_sample<false, true, false>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, node_time_dict, num_hops);
  return hetero_sample<false, false, false>(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, node_time_dict, num_hops);
}

// Temporal variant of `hetero_neighbor_sample_cpu`: a neighbor is only sampled
// if it satisfies the time constraint checked by `satisfy_time` (timestamps
// taken from `node_time_dict`) relative to the seed it descends from.
//
// The current implementation builds one disjoint computation tree per seed, so
// that the same node reached from seeds with different root times is tracked
// separately. This requires directed sampling; supporting `directed = false`
// (additional edges within each computation tree) is left for the future.
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
      c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_temporal_neighbor_sample_cpu(
    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
    const c10::Dict<rel_t, torch::Tensor> &row_dict,
    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
    const c10::Dict<node_t, torch::Tensor> &node_time_dict,
    const int64_t num_hops, const bool replace, const bool directed) {
  // `directed` is a user argument, so validate it with TORCH_CHECK (the
  // deprecated AT_ASSERTM reports failures as internal PyTorch bugs).
  TORCH_CHECK(directed, "Temporal sampling requires 'directed' sampling");

  if (replace) {
    return hetero_sample<true, true, true>(
        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
        num_neighbors_dict, node_time_dict, num_hops);
  }
  return hetero_sample<false, true, true>(
      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
      num_neighbors_dict, node_time_dict, num_hops);
}