neighbor.cc 24.9 KB
Newer Older
1
/**
 *  Copyright (c) 2020-2022 by Contributors
 * @file graph/sampling/neighbor.cc
 * @brief Definition of neighborhood-based sampler APIs.
 */

#include <dgl/array.h>
#include <dgl/aten/macro.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/sampling/neighbor.h>

#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "../../../c_api_common.h"
#include "../../unit_graph.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {
namespace sampling {

25
26
27
std::pair<HeteroSubgraph, std::vector<FloatArray>> ExcludeCertainEdges(
    const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges,
    const std::vector<FloatArray>* weights = nullptr) {
28
29
30
  HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr();
  std::vector<IdArray> remain_induced_edges(hg_view->NumEdgeTypes());
  std::vector<IdArray> remain_edges(hg_view->NumEdgeTypes());
31
  std::vector<FloatArray> remain_weights(hg_view->NumEdgeTypes());
32
33
34
35
36
37
38
39

  for (dgl_type_t etype = 0; etype < hg_view->NumEdgeTypes(); ++etype) {
    IdArray edge_ids = Range(
        0, sg.induced_edges[etype]->shape[0],
        sg.induced_edges[etype]->dtype.bits, sg.induced_edges[etype]->ctx);
    if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {
      remain_edges[etype] = edge_ids;
      remain_induced_edges[etype] = sg.induced_edges[etype];
40
      if (weights) remain_weights[etype] = (*weights)[etype];
41
      continue;
42
    }
43
    ATEN_ID_TYPE_SWITCH(hg_view->DataType(), IdType, {
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
      const auto dtype = weights && (*weights)[etype]->shape[0]
                             ? (*weights)[etype]->dtype
                             : DGLDataType{kDGLFloat, 8 * sizeof(float), 1};
      ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", {
        IdType* idx_data = edge_ids.Ptr<IdType>();
        IdType* induced_edges_data = sg.induced_edges[etype].Ptr<IdType>();
        FloatType* weights_data = weights && (*weights)[etype]->shape[0]
                                      ? (*weights)[etype].Ptr<FloatType>()
                                      : nullptr;
        const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
        std::sort(
            exclude_edges[etype].Ptr<IdType>(),
            exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
        const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
        IdType outId = 0;
        for (IdType i = 0; i != sg.induced_edges[etype]->shape[0]; ++i) {
          // the following binary search is the bottleneck, excluding weights
          // together with edges should almost be free.
          if (!std::binary_search(
                  exclude_edges_data, exclude_edges_data + exclude_edges_len,
                  induced_edges_data[i])) {
            induced_edges_data[outId] = induced_edges_data[i];
            idx_data[outId] = idx_data[i];
            if (weights_data) weights_data[outId] = weights_data[i];
            ++outId;
          }
70
        }
71
72
73
74
75
76
77
        remain_edges[etype] = aten::IndexSelect(edge_ids, 0, outId);
        remain_induced_edges[etype] =
            aten::IndexSelect(sg.induced_edges[etype], 0, outId);
        remain_weights[etype] =
            weights_data ? aten::IndexSelect((*weights)[etype], 0, outId)
                         : NullArray();
      });
78
79
80
81
    });
  }
  HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true);
  subg.induced_edges = std::move(remain_induced_edges);
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  return std::make_pair(subg, remain_weights);
}

std::pair<HeteroSubgraph, std::vector<FloatArray>> SampleLabors(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<FloatArray>& prob,
    const std::vector<IdArray>& exclude_edges, const int importance_sampling,
    const IdArray random_seed, const std::vector<IdArray>& NIDs) {
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
      << "Number of fanout values must match the number of edge types.";

  DGLContext ctx = aten::GetContextOf(nodes);

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<FloatArray> subimportances(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    const IdArray nodes_ntype =
        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
    const IdArray NIDs_ntype =
        NIDs[(dir == EdgeDir::kIn) ? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];
    if (num_nodes == 0 || fanouts[etype] == 0) {
      // Nothing to sample for this etype, create a placeholder relation graph
      subrels[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), ctx);
      induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
118
      subimportances[etype] = NullArray();
119
120
121
122
123
124
    } else {
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      COOMatrix sampled_coo;
      FloatArray importances;
125
126
127
128
129
      const int64_t fanout =
          fanouts[etype] >= 0
              ? fanouts[etype]
              : std::max(
                    hg->NumVertices(dst_vtype), hg->NumVertices(src_vtype));
130
131
132
133
134
      switch (avail_fmt) {
        case SparseFormat::kCOO:
          if (dir == EdgeDir::kIn) {
            auto fs = aten::COOLaborSampling(
                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
135
                fanout, prob[etype], importance_sampling, random_seed,
136
137
138
139
140
                NIDs_ntype);
            sampled_coo = aten::COOTranspose(fs.first);
            importances = fs.second;
          } else {
            std::tie(sampled_coo, importances) = aten::COOLaborSampling(
141
142
                hg->GetCOOMatrix(etype), nodes_ntype, fanout, prob[etype],
                importance_sampling, random_seed, NIDs_ntype);
143
144
145
146
147
148
          }
          break;
        case SparseFormat::kCSR:
          CHECK(dir == EdgeDir::kOut)
              << "Cannot sample out edges on CSC matrix.";
          std::tie(sampled_coo, importances) = aten::CSRLaborSampling(
149
              hg->GetCSRMatrix(etype), nodes_ntype, fanout, prob[etype],
150
151
152
153
154
              importance_sampling, random_seed, NIDs_ntype);
          break;
        case SparseFormat::kCSC:
          CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
          std::tie(sampled_coo, importances) = aten::CSRLaborSampling(
155
              hg->GetCSCMatrix(etype), nodes_ntype, fanout, prob[etype],
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
              importance_sampling, random_seed, NIDs_ntype);
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }
      subrels[etype] = UnitGraph::CreateFromCOO(
          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
      subimportances[etype] = importances;
      induced_edges[etype] = sampled_coo.data;
    }
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);

  if (!exclude_edges.empty())
    return ExcludeCertainEdges(ret, exclude_edges, &subimportances);

  return std::make_pair(ret, std::move(subimportances));
180
181
}

182
/**
 * @brief Sample neighbors of the given seed nodes, one relation at a time.
 *
 * For every edge type the seed nodes of the relevant node type are passed to
 * `aten::COORowWiseSampling` / `aten::CSRRowWiseSampling`, depending on which
 * sparse format `hg->SelectFormat` reports. Edge types with no seed nodes or
 * a zero fanout get an empty placeholder relation.
 *
 * @param hg Graph to sample from.
 * @param nodes Seed node IDs, one array per node type.
 * @param fanouts Fanout per edge type, forwarded to the sampling kernels.
 * @param dir Whether to sample along out-edges or in-edges of the seeds.
 * @param prob_or_mask One array per edge type, forwarded to the kernels.
 * @param exclude_edges If non-empty, edges to drop from the result (one
 *        array per edge type), applied via ExcludeCertainEdges.
 * @param replace Forwarded to the sampling kernels.
 * @return Subgraph with the same metagraph and node counts as `hg`;
 *         `induced_edges[etype]` holds the `data` array of the sampled COO.
 */
HeteroSubgraph SampleNeighbors(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<NDArray>& prob_or_mask,
    const std::vector<IdArray>& exclude_edges, bool replace) {
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
      << "Number of fanout values must match the number of edge types.";
  CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
      << "Number of probability tensors must match the number of edge types.";

  DGLContext ctx = aten::GetContextOf(nodes);

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    // Seeds live on the source side for out-edges, destination side for
    // in-edges.
    const IdArray nodes_ntype =
        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];

    if (num_nodes == 0 || fanouts[etype] == 0) {
      // Nothing to sample for this etype, create a placeholder relation graph
      subrels[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), ctx);
      induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
    } else {
      COOMatrix sampled_coo;
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      switch (avail_fmt) {
        case SparseFormat::kCOO:
          if (dir == EdgeDir::kIn) {
            // Transpose so the row-wise kernel sees in-edges as rows, then
            // transpose the sampled result back.
            sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(
                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
                fanouts[etype], prob_or_mask[etype], replace));
          } else {
            sampled_coo = aten::COORowWiseSampling(
                hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype],
                prob_or_mask[etype], replace);
          }
          break;
        case SparseFormat::kCSR:
          // This check fails only when in-edges are requested on a CSR.
          CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
          sampled_coo = aten::CSRRowWiseSampling(
              hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype],
              prob_or_mask[etype], replace);
          break;
        case SparseFormat::kCSC:
          // This check fails only when out-edges are requested on a CSC.
          CHECK(dir == EdgeDir::kIn)
              << "Cannot sample out edges on CSC matrix.";
          sampled_coo = aten::CSRRowWiseSampling(
              hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype],
              prob_or_mask[etype], replace);
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }

      subrels[etype] = UnitGraph::CreateFromCOO(
          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
      induced_edges[etype] = sampled_coo.data;
    }
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  if (!exclude_edges.empty()) {
    return ExcludeCertainEdges(ret, exclude_edges).first;
  }
  return ret;
}

267
/**
 * @brief Per-edge-type neighbor sampling on a graph stored as a single
 *        relation.
 *
 * The graph must have exactly one node type and one edge type; the actual
 * edge types are described by `eid2etype_offset`, which is forwarded to the
 * `aten::*RowWisePerEtypeSampling` kernels together with the per-type
 * fanouts and probabilities.
 *
 * @param hg Graph to sample from (one node type, one edge type).
 * @param nodes Seed node IDs.
 * @param eid2etype_offset Forwarded to the sampling kernels.
 * @param fanouts Fanout values, forwarded to the kernels; must be non-empty.
 *        If every value is zero nothing is sampled.
 * @param dir Whether to sample along out-edges or in-edges of the seeds.
 * @param prob Forwarded to the sampling kernels.
 * @param replace Forwarded to the sampling kernels.
 * @param rowwise_etype_sorted Forwarded to the CSR/CSC kernels only; the COO
 *        kernel does not take it.
 * @return Subgraph with the same metagraph and node counts as `hg`;
 *         `induced_edges[0]` holds the `data` array of the sampled COO.
 */
HeteroSubgraph SampleNeighborsEType(
    const HeteroGraphPtr hg, const IdArray nodes,
    const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<FloatArray>& prob, bool replace,
    bool rowwise_etype_sorted) {
  CHECK_EQ(1, hg->NumVertexTypes())
      << "SampleNeighborsEType only work with homogeneous graph";
  CHECK_EQ(1, hg->NumEdgeTypes())
      << "SampleNeighborsEType only work with homogeneous graph";
  // The original code read fanouts[0] unconditionally; make the requirement
  // explicit instead of indexing an empty vector.
  CHECK(!fanouts.empty()) << "At least one fanout value is required.";

  std::vector<HeteroGraphPtr> subrels(1);
  std::vector<IdArray> induced_edges(1);
  const int64_t num_nodes = nodes->shape[0];
  dgl_type_t etype = 0;
  const dgl_type_t src_vtype = 0;
  const dgl_type_t dst_vtype = 0;

  // Equivalent to "all fanouts are equal and the common value is zero".
  const bool all_fanouts_zero = std::all_of(
      fanouts.begin(), fanouts.end(), [](int64_t f) { return f == 0; });

  if (num_nodes == 0 || all_fanouts_zero) {
    // Nothing to sample, create a placeholder relation graph. The empty
    // edge-ID array uses the graph's ID dtype and context, as in
    // SampleNeighbors.
    subrels[etype] = UnitGraph::Empty(
        1, hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
        hg->DataType(), hg->Context());
    induced_edges[etype] = aten::NullArray(hg->DataType(), hg->Context());
  } else {
    COOMatrix sampled_coo;
    // sample from graph
    // the edge type is stored in etypes
    auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    auto avail_fmt = hg->SelectFormat(etype, req_fmt);
    switch (avail_fmt) {
      case SparseFormat::kCOO:
        if (dir == EdgeDir::kIn) {
          // Transpose so the row-wise kernel sees in-edges as rows, then
          // transpose the sampled result back.
          sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
              aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes,
              eid2etype_offset, fanouts, prob, replace));
        } else {
          sampled_coo = aten::COORowWisePerEtypeSampling(
              hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
              replace);
        }
        break;
      case SparseFormat::kCSR:
        // This check fails only when in-edges are requested on a CSR.
        CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
        sampled_coo = aten::CSRRowWisePerEtypeSampling(
            hg->GetCSRMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
            replace, rowwise_etype_sorted);
        break;
      case SparseFormat::kCSC:
        // This check fails only when out-edges are requested on a CSC.
        CHECK(dir == EdgeDir::kIn) << "Cannot sample out edges on CSC matrix.";
        sampled_coo = aten::CSRRowWisePerEtypeSampling(
            hg->GetCSCMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
            replace, rowwise_etype_sorted);
        sampled_coo = aten::COOTranspose(sampled_coo);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }

    subrels[etype] = UnitGraph::CreateFromCOO(
        1, sampled_coo.num_rows, sampled_coo.num_cols, sampled_coo.row,
        sampled_coo.col);
    induced_edges[etype] = sampled_coo.data;
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}

348
/**
 * @brief Select, for each seed node, its top-k neighbors by edge weight, one
 *        relation at a time.
 *
 * For every edge type the seed nodes of the relevant node type are passed to
 * `aten::COORowWiseTopk` / `aten::CSRRowWiseTopk`, depending on which sparse
 * format `hg->SelectFormat` reports. Edge types with no seed nodes or k == 0
 * get an empty placeholder relation.
 *
 * @param hg Graph to sample from.
 * @param nodes Seed node IDs, one array per node type.
 * @param k Number of neighbors to keep per edge type, forwarded to the
 *        kernels.
 * @param dir Whether to pick along out-edges or in-edges of the seeds.
 * @param weight One weight array per edge type, forwarded to the kernels.
 * @param ascending Forwarded to the kernels.
 * @return Subgraph with the same metagraph and node counts as `hg`;
 *         `induced_edges[etype]` holds the `data` array of the selected COO.
 */
HeteroSubgraph SampleNeighborsTopk(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& k, EdgeDir dir,
    const std::vector<FloatArray>& weight, bool ascending) {
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(k.size(), hg->NumEdgeTypes())
      << "Number of k values must match the number of edge types.";
  CHECK_EQ(weight.size(), hg->NumEdgeTypes())
      << "Number of weight tensors must match the number of edge types.";

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    // Seeds live on the source side for out-edges, destination side for
    // in-edges.
    const IdArray nodes_ntype =
        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];
    if (num_nodes == 0 || k[etype] == 0) {
      // Nothing to sample for this etype, create a placeholder relation
      // graph. The empty edge-ID array uses the graph's ID dtype and
      // context, as in SampleNeighbors.
      subrels[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), hg->Context());
      induced_edges[etype] = aten::NullArray(hg->DataType(), hg->Context());
    } else {
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      COOMatrix sampled_coo;
      switch (avail_fmt) {
        case SparseFormat::kCOO:
          if (dir == EdgeDir::kIn) {
            // Transpose so the row-wise kernel sees in-edges as rows, then
            // transpose the selected result back.
            sampled_coo = aten::COOTranspose(aten::COORowWiseTopk(
                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
                k[etype], weight[etype], ascending));
          } else {
            sampled_coo = aten::COORowWiseTopk(
                hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype],
                ascending);
          }
          break;
        case SparseFormat::kCSR:
          // This check fails only when in-edges are requested on a CSR.
          CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
          sampled_coo = aten::CSRRowWiseTopk(
              hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype],
              ascending);
          break;
        case SparseFormat::kCSC:
          // This check fails only when out-edges are requested on a CSC.
          CHECK(dir == EdgeDir::kIn)
              << "Cannot sample out edges on CSC matrix.";
          sampled_coo = aten::CSRRowWiseTopk(
              hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype],
              ascending);
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }
      subrels[etype] = UnitGraph::CreateFromCOO(
          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
      induced_edges[etype] = sampled_coo.data;
    }
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}

425
/**
 * @brief Tag-biased neighbor sampling on a graph with a single edge type.
 *
 * Sampling is delegated to `aten::CSRRowWiseSamplingBiased` on the CSR
 * (out-edges) or CSC (in-edges) matrix of the only relation. The required
 * format must already be among the graph's created formats; it is not
 * materialized here.
 *
 * @param hg Graph to sample from; must have exactly one edge type.
 * @param nodes Seed node IDs.
 * @param fanout Number of neighbors per seed, forwarded to the kernel; zero
 *        yields an empty result.
 * @param bias Per-tag bias, forwarded to the kernel.
 * @param tag_offset 2-D array of shape [num_nodes, num_tags + 1], where
 *        num_nodes is the number of nodes of the seed node type and
 *        num_tags equals `bias->shape[0]`.
 * @param dir Whether to sample along out-edges or in-edges of the seeds.
 * @param replace Forwarded to the kernel.
 * @return Subgraph with the same metagraph and node counts as `hg`;
 *         `induced_edges[0]` holds the `data` array of the sampled COO.
 */
HeteroSubgraph SampleNeighborsBiased(
    const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanout,
    const NDArray& bias, const NDArray& tag_offset, const EdgeDir dir,
    const bool replace) {
  CHECK_EQ(hg->NumEdgeTypes(), 1)
      << "Only homogeneous or bipartite graphs are supported";
  auto pair = hg->meta_graph()->FindEdge(0);
  const dgl_type_t src_vtype = pair.first;
  const dgl_type_t dst_vtype = pair.second;
  // Node type the seeds belong to: source side for out-edges, destination
  // side for in-edges.
  const dgl_type_t nodes_ntype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype;

  // sanity check
  CHECK_EQ(tag_offset->ndim, 2)
      << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[0], hg->NumVertices(nodes_ntype))
      << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1)
      << "The sizes of tag_offset and bias are inconsistent";

  const int64_t num_nodes = nodes->shape[0];
  HeteroGraphPtr subrel;
  IdArray induced_edges;
  const dgl_type_t etype = 0;
  if (num_nodes == 0 || fanout == 0) {
    // Nothing to sample for this etype, create a placeholder relation graph.
    // The empty edge-ID array uses the graph's ID dtype and context, as in
    // SampleNeighbors.
    subrel = UnitGraph::Empty(
        hg->GetRelationGraph(etype)->NumVertexTypes(),
        hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype), hg->DataType(),
        hg->Context());
    induced_edges = aten::NullArray(hg->DataType(), hg->Context());
  } else {
    // sample from one relation graph
    const auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    const auto created_fmt = hg->GetCreatedFormats();
    COOMatrix sampled_coo;

    switch (req_fmt) {
      case CSR_CODE:
        CHECK(created_fmt & CSR_CODE) << "A sorted CSR Matrix is required.";
        sampled_coo = aten::CSRRowWiseSamplingBiased(
            hg->GetCSRMatrix(etype), nodes, fanout, tag_offset, bias, replace);
        break;
      case CSC_CODE:
        CHECK(created_fmt & CSC_CODE) << "A sorted CSC Matrix is required.";
        sampled_coo = aten::CSRRowWiseSamplingBiased(
            hg->GetCSCMatrix(etype), nodes, fanout, tag_offset, bias, replace);
        // The kernel ran on the CSC matrix; transpose the result back.
        sampled_coo = aten::COOTranspose(sampled_coo);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
    subrel = UnitGraph::CreateFromCOO(
        hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
        sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
    induced_edges = sampled_coo.data;
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), {subrel}, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = {induced_edges};
  return ret;
}

490
// C API entry point for SampleNeighborsEType.
// Positional arguments: 0 graph, 1 seed nodes, 2 eid2etype offsets (list),
// 3 fanout (int64 array), 4 direction ("in"/"out"), 5 prob (list of arrays),
// 6 replace flag, 7 rowwise_etype_sorted flag.
// Returns the sampled subgraph wrapped in a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Unpack the arguments in positional order.
      HeteroGraphRef graph = args[0];
      IdArray seed_nodes = args[1];
      const std::vector<int64_t> etype_offsets =
          ListValueToVector<int64_t>(args[2]);
      IdArray fanout = args[3];
      const std::string direction = args[4];
      const auto edge_probs = ListValueToVector<FloatArray>(args[5]);
      const bool with_replacement = args[6];
      const bool etype_sorted = args[7];

      const bool is_in = (direction == "in");
      CHECK(is_in || direction == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      const EdgeDir dir = is_in ? EdgeDir::kIn : EdgeDir::kOut;

      // The fanout tensor must be int64 before it is read as int64_t values.
      CHECK_INT64(fanout, "fanout");
      const std::vector<int64_t> fanout_values = fanout.ToVector<int64_t>();

      auto sampled =
          std::make_shared<HeteroSubgraph>(sampling::SampleNeighborsEType(
              graph.sptr(), seed_nodes, etype_offsets, fanout_values, dir,
              edge_probs, with_replacement, etype_sorted));
      *rv = HeteroSubgraphRef(sampled);
    });
514

515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
// C API entry point for SampleLabors.
// Positional arguments: 0 graph, 1 seed nodes (list), 2 fanouts (array),
// 3 direction ("in"/"out"), 4 prob (list), 5 exclude edges (list),
// 6 importance_sampling, 7 random seed array, 8 NIDs (list).
// Returns a list: the sampled subgraph followed by one importance array per
// entry of the vector returned by SampleLabors.
DGL_REGISTER_GLOBAL("sampling.labor._CAPI_DGLSampleLabors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Unpack the arguments in positional order.
      HeteroGraphRef graph = args[0];
      const auto seed_nodes = ListValueToVector<IdArray>(args[1]);
      IdArray fanouts_nd = args[2];
      const auto fanout_values = fanouts_nd.ToVector<int64_t>();
      const std::string direction = args[3];
      const auto edge_probs = ListValueToVector<FloatArray>(args[4]);
      const auto excluded = ListValueToVector<IdArray>(args[5]);
      const int importance_sampling = args[6];
      const IdArray random_seed = args[7];
      const auto node_ids = ListValueToVector<IdArray>(args[8]);

      const bool is_in = (direction == "in");
      CHECK(is_in || direction == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      const EdgeDir dir = is_in ? EdgeDir::kIn : EdgeDir::kOut;

      auto sampled = sampling::SampleLabors(
          graph.sptr(), seed_nodes, fanout_values, dir, edge_probs, excluded,
          importance_sampling, random_seed, node_ids);

      // First element is the subgraph, the rest are the importance arrays.
      auto subgraph = std::make_shared<HeteroSubgraph>(sampled.first);
      List<Value> result;
      result.push_back(Value(subgraph));
      for (auto& importance : sampled.second) {
        result.push_back(Value(MakeValue(importance)));
      }

      *rv = result;
    });

546
// C API entry point for SampleNeighbors.
// Positional arguments: 0 graph, 1 seed nodes (list), 2 fanouts (array),
// 3 direction ("in"/"out"), 4 prob or mask (list), 5 exclude edges (list),
// 6 replace flag.
// Returns the sampled subgraph wrapped in a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Unpack the arguments in positional order.
      HeteroGraphRef graph = args[0];
      const auto seed_nodes = ListValueToVector<IdArray>(args[1]);
      IdArray fanouts_nd = args[2];
      const auto fanout_values = fanouts_nd.ToVector<int64_t>();
      const std::string direction = args[3];
      const auto probs_or_masks = ListValueToVector<NDArray>(args[4]);
      const auto excluded = ListValueToVector<IdArray>(args[5]);
      const bool with_replacement = args[6];

      const bool is_in = (direction == "in");
      CHECK(is_in || direction == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      const EdgeDir dir = is_in ? EdgeDir::kIn : EdgeDir::kOut;

      auto sampled = std::make_shared<HeteroSubgraph>(sampling::SampleNeighbors(
          graph.sptr(), seed_nodes, fanout_values, dir, probs_or_masks,
          excluded, with_replacement));

      *rv = HeteroSubgraphRef(sampled);
    });
567
568

// C API entry point for SampleNeighborsTopk.
// Positional arguments: 0 graph, 1 seed nodes (list), 2 k (array),
// 3 direction ("in"/"out"), 4 weight (list), 5 ascending flag.
// Returns the selected subgraph wrapped in a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      const auto& nodes = ListValueToVector<IdArray>(args[1]);
      IdArray k_array = args[2];
      const auto& k = k_array.ToVector<int64_t>();
      const std::string dir_str = args[3];
      const auto& weight = ListValueToVector<FloatArray>(args[4]);
      const bool ascending = args[5];

      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
      *subg = sampling::SampleNeighborsTopk(
          hg.sptr(), nodes, k, dir, weight, ascending);

      // The payload is a HeteroSubgraph, so wrap it in the matching reference
      // type, as the other sampler entry points in this file do.
      *rv = HeteroSubgraphRef(subg);
    });
588

589
// C API entry point for SampleNeighborsBiased.
// Positional arguments: 0 graph, 1 seed nodes, 2 fanout, 3 bias,
// 4 tag_offset, 5 direction ("in"/"out"), 6 replace flag.
// Returns the sampled subgraph wrapped in a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      const IdArray nodes = args[1];
      const int64_t fanout = args[2];
      const NDArray bias = args[3];
      const NDArray tag_offset = args[4];
      const std::string dir_str = args[5];
      const bool replace = args[6];

      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
      *subg = sampling::SampleNeighborsBiased(
          hg.sptr(), nodes, fanout, bias, tag_offset, dir, replace);

      // The payload is a HeteroSubgraph, so wrap it in the matching reference
      // type, as the other sampler entry points in this file do.
      *rv = HeteroSubgraphRef(subg);
    });
609

610
611
}  // namespace sampling
}  // namespace dgl