neighbor.cc 25.1 KB
Newer Older
1
/**
 *  Copyright (c) 2020-2022 by Contributors
 * @file graph/sampling/neighbor.cc
 * @brief Definition of neighborhood-based sampler APIs.
 */

#include <dgl/array.h>
8
#include <dgl/aten/macro.h>
9
10
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
11
#include <dgl/sampling/neighbor.h>
12

13
14
15
#include <tuple>
#include <utility>

16
17
18
19
20
21
22
23
24
#include "../../../c_api_common.h"
#include "../../unit_graph.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {
namespace sampling {

25
26
27
std::pair<HeteroSubgraph, std::vector<FloatArray>> ExcludeCertainEdges(
    const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges,
    const std::vector<FloatArray>* weights = nullptr) {
28
29
30
  HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr();
  std::vector<IdArray> remain_induced_edges(hg_view->NumEdgeTypes());
  std::vector<IdArray> remain_edges(hg_view->NumEdgeTypes());
31
  std::vector<FloatArray> remain_weights(hg_view->NumEdgeTypes());
32
33
34
35
36
37
38
39

  for (dgl_type_t etype = 0; etype < hg_view->NumEdgeTypes(); ++etype) {
    IdArray edge_ids = Range(
        0, sg.induced_edges[etype]->shape[0],
        sg.induced_edges[etype]->dtype.bits, sg.induced_edges[etype]->ctx);
    if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {
      remain_edges[etype] = edge_ids;
      remain_induced_edges[etype] = sg.induced_edges[etype];
40
      if (weights) remain_weights[etype] = (*weights)[etype];
41
      continue;
42
    }
43
    ATEN_ID_TYPE_SWITCH(hg_view->DataType(), IdType, {
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
      const auto dtype = weights && (*weights)[etype]->shape[0]
                             ? (*weights)[etype]->dtype
                             : DGLDataType{kDGLFloat, 8 * sizeof(float), 1};
      ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", {
        IdType* idx_data = edge_ids.Ptr<IdType>();
        IdType* induced_edges_data = sg.induced_edges[etype].Ptr<IdType>();
        FloatType* weights_data = weights && (*weights)[etype]->shape[0]
                                      ? (*weights)[etype].Ptr<FloatType>()
                                      : nullptr;
        const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
        std::sort(
            exclude_edges[etype].Ptr<IdType>(),
            exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
        const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
        IdType outId = 0;
        for (IdType i = 0; i != sg.induced_edges[etype]->shape[0]; ++i) {
          // the following binary search is the bottleneck, excluding weights
          // together with edges should almost be free.
          if (!std::binary_search(
                  exclude_edges_data, exclude_edges_data + exclude_edges_len,
                  induced_edges_data[i])) {
            induced_edges_data[outId] = induced_edges_data[i];
            idx_data[outId] = idx_data[i];
            if (weights_data) weights_data[outId] = weights_data[i];
            ++outId;
          }
70
        }
71
72
73
74
75
76
77
        remain_edges[etype] = aten::IndexSelect(edge_ids, 0, outId);
        remain_induced_edges[etype] =
            aten::IndexSelect(sg.induced_edges[etype], 0, outId);
        remain_weights[etype] =
            weights_data ? aten::IndexSelect((*weights)[etype], 0, outId)
                         : NullArray();
      });
78
79
80
81
    });
  }
  HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true);
  subg.induced_edges = std::move(remain_induced_edges);
82
83
84
85
86
87
88
89
  return std::make_pair(subg, remain_weights);
}

std::pair<HeteroSubgraph, std::vector<FloatArray>> SampleLabors(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<FloatArray>& prob,
    const std::vector<IdArray>& exclude_edges, const int importance_sampling,
90
91
    const IdArray random_seed, const float seed2_contribution,
    const std::vector<IdArray>& NIDs) {
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
      << "Number of fanout values must match the number of edge types.";

  DGLContext ctx = aten::GetContextOf(nodes);

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<FloatArray> subimportances(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    const IdArray nodes_ntype =
        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
    const IdArray NIDs_ntype =
        NIDs[(dir == EdgeDir::kIn) ? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];
    if (num_nodes == 0 || fanouts[etype] == 0) {
      // Nothing to sample for this etype, create a placeholder relation graph
      subrels[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), ctx);
      induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
119
      subimportances[etype] = NullArray();
120
121
122
123
124
125
    } else {
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      COOMatrix sampled_coo;
      FloatArray importances;
126
127
128
129
130
      const int64_t fanout =
          fanouts[etype] >= 0
              ? fanouts[etype]
              : std::max(
                    hg->NumVertices(dst_vtype), hg->NumVertices(src_vtype));
131
132
133
134
135
      switch (avail_fmt) {
        case SparseFormat::kCOO:
          if (dir == EdgeDir::kIn) {
            auto fs = aten::COOLaborSampling(
                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
136
                fanout, prob[etype], importance_sampling, random_seed,
137
                seed2_contribution, NIDs_ntype);
138
139
140
141
            sampled_coo = aten::COOTranspose(fs.first);
            importances = fs.second;
          } else {
            std::tie(sampled_coo, importances) = aten::COOLaborSampling(
142
                hg->GetCOOMatrix(etype), nodes_ntype, fanout, prob[etype],
143
144
                importance_sampling, random_seed, seed2_contribution,
                NIDs_ntype);
145
146
147
148
149
150
          }
          break;
        case SparseFormat::kCSR:
          CHECK(dir == EdgeDir::kOut)
              << "Cannot sample out edges on CSC matrix.";
          std::tie(sampled_coo, importances) = aten::CSRLaborSampling(
151
              hg->GetCSRMatrix(etype), nodes_ntype, fanout, prob[etype],
152
              importance_sampling, random_seed, seed2_contribution, NIDs_ntype);
153
154
155
156
          break;
        case SparseFormat::kCSC:
          CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
          std::tie(sampled_coo, importances) = aten::CSRLaborSampling(
157
              hg->GetCSCMatrix(etype), nodes_ntype, fanout, prob[etype],
158
              importance_sampling, random_seed, seed2_contribution, NIDs_ntype);
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }
      subrels[etype] = UnitGraph::CreateFromCOO(
          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
      subimportances[etype] = importances;
      induced_edges[etype] = sampled_coo.data;
    }
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);

  if (!exclude_edges.empty())
    return ExcludeCertainEdges(ret, exclude_edges, &subimportances);

  return std::make_pair(ret, std::move(subimportances));
182
183
}

184
/**
 * @brief Row-wise neighbor sampling over every edge type of a graph.
 *
 * @param hg The graph to sample from.
 * @param nodes Seed node IDs, one array per node type.
 * @param fanouts One fanout per edge type; zero skips the edge type.
 * @param dir Whether to sample in-edges or out-edges of the seeds.
 * @param prob_or_mask One array per edge type, forwarded to the sampling
 *        kernel.
 * @param exclude_edges Edge IDs to remove from the result per edge type; an
 *        empty vector disables the exclusion step.
 * @param replace Forwarded to the sampling kernel.
 * @return The sampled subgraph; induced_edges holds the sampled edge IDs.
 */
HeteroSubgraph SampleNeighbors(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<NDArray>& prob_or_mask,
    const std::vector<IdArray>& exclude_edges, bool replace) {
  // Validate the per-type argument lists before indexing into them.
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
      << "Number of fanout values must match the number of edge types.";
  CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
      << "Number of probability tensors must match the number of edge types.";

  DGLContext ctx = aten::GetContextOf(nodes);

  std::vector<HeteroGraphPtr> rel_graphs(hg->NumEdgeTypes());
  std::vector<IdArray> picked_eids(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // Seeds sit on the source side for out-edges, destination side otherwise.
    const IdArray seeds = nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];

    if (seeds->shape[0] == 0 || fanouts[etype] == 0) {
      // No seeds or zero fanout: emit an empty placeholder relation.
      rel_graphs[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), ctx);
      picked_eids[etype] = aten::NullArray(hg->DataType(), ctx);
      continue;
    }

    // Sample this relation from whichever format is reported as available.
    COOMatrix sampled;
    const auto wanted_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    switch (hg->SelectFormat(etype, wanted_fmt)) {
      case SparseFormat::kCOO:
        if (dir == EdgeDir::kIn) {
          // Sample rows of the transposed matrix, then transpose back.
          sampled = aten::COOTranspose(aten::COORowWiseSampling(
              aten::COOTranspose(hg->GetCOOMatrix(etype)), seeds,
              fanouts[etype], prob_or_mask[etype], replace));
        } else {
          sampled = aten::COORowWiseSampling(
              hg->GetCOOMatrix(etype), seeds, fanouts[etype],
              prob_or_mask[etype], replace);
        }
        break;
      case SparseFormat::kCSR:
        CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
        sampled = aten::CSRRowWiseSampling(
            hg->GetCSRMatrix(etype), seeds, fanouts[etype],
            prob_or_mask[etype], replace);
        break;
      case SparseFormat::kCSC:
        CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
        sampled = aten::CSRRowWiseSampling(
            hg->GetCSCMatrix(etype), seeds, fanouts[etype],
            prob_or_mask[etype], replace);
        sampled = aten::COOTranspose(sampled);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }

    rel_graphs[etype] = UnitGraph::CreateFromCOO(
        hg->GetRelationGraph(etype)->NumVertexTypes(), sampled.num_rows,
        sampled.num_cols, sampled.row, sampled.col);
    picked_eids[etype] = sampled.data;
  }

  HeteroSubgraph result;
  result.graph =
      CreateHeteroGraph(hg->meta_graph(), rel_graphs, hg->NumVerticesPerType());
  result.induced_vertices.resize(hg->NumVertexTypes());
  result.induced_edges = std::move(picked_eids);
  if (exclude_edges.empty()) {
    return result;
  }
  return ExcludeCertainEdges(result, exclude_edges).first;
}

269
/**
 * @brief Per-edge-type row-wise neighbor sampling on a graph that stores all
 *        edges under a single node type and a single edge type.
 *
 * @param hg The graph to sample from; must have exactly one node type and one
 *        edge type.
 * @param nodes Seed node IDs.
 * @param eid2etype_offset Forwarded to the sampling kernel.
 * @param fanouts Fanout values forwarded to the sampling kernel; must not be
 *        empty. When every value is zero nothing is sampled.
 * @param dir Whether to sample in-edges or out-edges of the seeds.
 * @param prob Forwarded to the sampling kernel.
 * @param replace Forwarded to the sampling kernel.
 * @param rowwise_etype_sorted Forwarded to the CSR/CSC sampling kernel.
 * @return The sampled subgraph; induced_edges holds the sampled edge IDs.
 */
HeteroSubgraph SampleNeighborsEType(
    const HeteroGraphPtr hg, const IdArray nodes,
    const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<FloatArray>& prob, bool replace,
    bool rowwise_etype_sorted) {
  CHECK_EQ(1, hg->NumVertexTypes())
      << "SampleNeighborsEType only work with homogeneous graph";
  CHECK_EQ(1, hg->NumEdgeTypes())
      << "SampleNeighborsEType only work with homogeneous graph";
  // The first fanout used to be read unconditionally, which is undefined
  // behaviour on an empty vector; fail with a clear message instead.
  CHECK(!fanouts.empty()) << "At least one fanout value must be given.";

  std::vector<HeteroGraphPtr> subrels(1);
  std::vector<IdArray> induced_edges(1);
  const int64_t num_nodes = nodes->shape[0];
  dgl_type_t etype = 0;
  const dgl_type_t src_vtype = 0;
  const dgl_type_t dst_vtype = 0;

  // "All fanouts share one value and that value is zero" is the same as
  // "every fanout is zero".
  bool all_fanouts_zero = true;
  for (const int64_t fanout : fanouts) {
    if (fanout != 0) {
      all_fanouts_zero = false;
      break;
    }
  }

  if (num_nodes == 0 || all_fanouts_zero) {
    // Nothing to sample, create a placeholder relation graph
    subrels[etype] = UnitGraph::Empty(
        1, hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
        hg->DataType(), hg->Context());
    induced_edges[etype] = aten::NullArray();
  } else {
    COOMatrix sampled_coo;
    // sample from graph
    // the edge type is stored in etypes
    auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    auto avail_fmt = hg->SelectFormat(etype, req_fmt);
    switch (avail_fmt) {
      case SparseFormat::kCOO:
        if (dir == EdgeDir::kIn) {
          // Sample rows of the transposed matrix, then transpose back.
          sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
              aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes,
              eid2etype_offset, fanouts, prob, replace));
        } else {
          sampled_coo = aten::COORowWisePerEtypeSampling(
              hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
              replace);
        }
        break;
      case SparseFormat::kCSR:
        CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
        sampled_coo = aten::CSRRowWisePerEtypeSampling(
            hg->GetCSRMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
            replace, rowwise_etype_sorted);
        break;
      case SparseFormat::kCSC:
        CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
        sampled_coo = aten::CSRRowWisePerEtypeSampling(
            hg->GetCSCMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
            replace, rowwise_etype_sorted);
        sampled_coo = aten::COOTranspose(sampled_coo);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }

    subrels[etype] = UnitGraph::CreateFromCOO(
        1, sampled_coo.num_rows, sampled_coo.num_cols, sampled_coo.row,
        sampled_coo.col);
    induced_edges[etype] = sampled_coo.data;
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}

350
/**
 * @brief Row-wise top-k neighbor selection over every edge type of a graph.
 *
 * @param hg The graph to select from.
 * @param nodes Seed node IDs, one array per node type.
 * @param k One k value per edge type; zero skips the edge type.
 * @param dir Whether to consider in-edges or out-edges of the seeds.
 * @param weight One weight array per edge type, forwarded to the top-k kernel.
 * @param ascending Forwarded to the top-k kernel.
 * @return The selected subgraph; induced_edges holds the selected edge IDs.
 */
HeteroSubgraph SampleNeighborsTopk(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& k, EdgeDir dir,
    const std::vector<FloatArray>& weight, bool ascending) {
  // Validate the per-type argument lists before indexing into them.
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(k.size(), hg->NumEdgeTypes())
      << "Number of k values must match the number of edge types.";
  CHECK_EQ(weight.size(), hg->NumEdgeTypes())
      << "Number of weight tensors must match the number of edge types.";

  std::vector<HeteroGraphPtr> rel_graphs(hg->NumEdgeTypes());
  std::vector<IdArray> picked_eids(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    const auto endpoints = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // Seeds sit on the source side for out-edges, destination side otherwise.
    const IdArray seeds = nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];

    if (seeds->shape[0] == 0 || k[etype] == 0) {
      // No seeds or k == 0: emit an empty placeholder relation.
      rel_graphs[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), hg->Context());
      picked_eids[etype] = aten::NullArray();
      continue;
    }

    // Select from this relation using whichever format is reported available.
    const auto wanted_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    COOMatrix selected;
    switch (hg->SelectFormat(etype, wanted_fmt)) {
      case SparseFormat::kCOO:
        if (dir == EdgeDir::kIn) {
          // Work on rows of the transposed matrix, then transpose back.
          selected = aten::COOTranspose(aten::COORowWiseTopk(
              aten::COOTranspose(hg->GetCOOMatrix(etype)), seeds, k[etype],
              weight[etype], ascending));
        } else {
          selected = aten::COORowWiseTopk(
              hg->GetCOOMatrix(etype), seeds, k[etype], weight[etype],
              ascending);
        }
        break;
      case SparseFormat::kCSR:
        CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
        selected = aten::CSRRowWiseTopk(
            hg->GetCSRMatrix(etype), seeds, k[etype], weight[etype], ascending);
        break;
      case SparseFormat::kCSC:
        CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
        selected = aten::CSRRowWiseTopk(
            hg->GetCSCMatrix(etype), seeds, k[etype], weight[etype], ascending);
        selected = aten::COOTranspose(selected);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }

    rel_graphs[etype] = UnitGraph::CreateFromCOO(
        hg->GetRelationGraph(etype)->NumVertexTypes(), selected.num_rows,
        selected.num_cols, selected.row, selected.col);
    picked_eids[etype] = selected.data;
  }

  HeteroSubgraph result;
  result.graph =
      CreateHeteroGraph(hg->meta_graph(), rel_graphs, hg->NumVerticesPerType());
  result.induced_vertices.resize(hg->NumVertexTypes());
  result.induced_edges = std::move(picked_eids);
  return result;
}

427
/**
 * @brief Biased row-wise neighbor sampling on a graph with one edge type.
 *
 * @param hg The graph to sample from; must have exactly one edge type.
 * @param nodes Seed node IDs.
 * @param fanout Number of neighbors requested per seed; zero samples nothing.
 * @param bias Bias array, forwarded to the sampling kernel. Its length plus
 *        one must equal the second dimension of tag_offset.
 * @param tag_offset Two-dimensional array whose first dimension must equal the
 *        number of vertices on the seeds' side of the relation.
 * @param dir Whether to sample in-edges or out-edges of the seeds.
 * @param replace Forwarded to the sampling kernel.
 * @return The sampled subgraph; induced_edges holds the sampled edge IDs.
 */
HeteroSubgraph SampleNeighborsBiased(
    const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanout,
    const NDArray& bias, const NDArray& tag_offset, const EdgeDir dir,
    const bool replace) {
  CHECK_EQ(hg->NumEdgeTypes(), 1)
      << "Only homogeneous or bipartite graphs are supported";
  const auto endpoints = hg->meta_graph()->FindEdge(0);
  const dgl_type_t src_vtype = endpoints.first;
  const dgl_type_t dst_vtype = endpoints.second;
  // Vertex type the seeds belong to.
  const dgl_type_t seed_vtype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype;

  // Shape checks on tag_offset and bias.
  CHECK_EQ(tag_offset->ndim, 2)
      << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[0], hg->NumVertices(seed_vtype))
      << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1)
      << "The sizes of tag_offset and bias are inconsistent";

  const dgl_type_t etype = 0;
  HeteroGraphPtr rel_graph;
  IdArray picked_eids;
  if (nodes->shape[0] == 0 || fanout == 0) {
    // No seeds or zero fanout: emit an empty placeholder relation.
    rel_graph = UnitGraph::Empty(
        hg->GetRelationGraph(etype)->NumVertexTypes(),
        hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype), hg->DataType(),
        hg->Context());
    picked_eids = aten::NullArray();
  } else {
    // The required format must already have been created on the graph.
    const auto wanted_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    const auto existing_fmts = hg->GetCreatedFormats();
    COOMatrix sampled;

    switch (wanted_fmt) {
      case CSR_CODE:
        CHECK(existing_fmts & CSR_CODE) << "A sorted CSR Matrix is required.";
        sampled = aten::CSRRowWiseSamplingBiased(
            hg->GetCSRMatrix(etype), nodes, fanout, tag_offset, bias, replace);
        break;
      case CSC_CODE:
        CHECK(existing_fmts & CSC_CODE) << "A sorted CSC Matrix is required.";
        sampled = aten::CSRRowWiseSamplingBiased(
            hg->GetCSCMatrix(etype), nodes, fanout, tag_offset, bias, replace);
        // The CSC matrix was sampled row-wise; transpose the result back.
        sampled = aten::COOTranspose(sampled);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }
    rel_graph = UnitGraph::CreateFromCOO(
        hg->GetRelationGraph(etype)->NumVertexTypes(), sampled.num_rows,
        sampled.num_cols, sampled.row, sampled.col);
    picked_eids = sampled.data;
  }

  HeteroSubgraph result;
  result.graph = CreateHeteroGraph(
      hg->meta_graph(), {rel_graph}, hg->NumVerticesPerType());
  result.induced_vertices.resize(hg->NumVertexTypes());
  result.induced_edges = {picked_eids};
  return result;
}

492
// Packed-function entry point for SampleNeighborsEType.
// Arguments: graph, nodes, eid2etype_offset (list), fanout (int64 array),
// direction ("in" or "out"), prob (list), replace, rowwise_etype_sorted.
// Returns the sampled subgraph as a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      IdArray seeds = args[1];
      const std::vector<int64_t>& etype_offsets =
          ListValueToVector<int64_t>(args[2]);
      IdArray fanout = args[3];
      const std::string direction = args[4];
      const auto& prob = ListValueToVector<FloatArray>(args[5]);
      const bool replace = args[6];
      const bool etype_sorted = args[7];

      CHECK(direction == "in" || direction == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      const EdgeDir dir = (direction == "in") ? EdgeDir::kIn : EdgeDir::kOut;
      CHECK_INT64(fanout, "fanout");
      const std::vector<int64_t> fanouts = fanout.ToVector<int64_t>();

      auto result = std::make_shared<HeteroSubgraph>();
      *result = sampling::SampleNeighborsEType(
          graph.sptr(), seeds, etype_offsets, fanouts, dir, prob, replace,
          etype_sorted);
      *rv = HeteroSubgraphRef(result);
    });
516

517
518
519
520
521
522
523
524
525
526
527
// Packed-function entry point for SampleLabors.
// Arguments: graph, nodes (list), fanouts (array), direction ("in" or "out"),
// prob (list), exclude_edges (list), importance_sampling, random_seed,
// seed2_contribution, NIDs (list).
// Returns a list whose first element is the sampled subgraph, followed by one
// importance array per entry returned by SampleLabors.
DGL_REGISTER_GLOBAL("sampling.labor._CAPI_DGLSampleLabors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const auto& seeds = ListValueToVector<IdArray>(args[1]);
      IdArray fanouts_nd = args[2];
      const auto& fanouts = fanouts_nd.ToVector<int64_t>();
      const std::string direction = args[3];
      const auto& prob = ListValueToVector<FloatArray>(args[4]);
      const auto& excluded = ListValueToVector<IdArray>(args[5]);
      const int importance_sampling = args[6];
      const IdArray random_seed = args[7];
      const double seed2_contribution = args[8];
      const auto& NIDs = ListValueToVector<IdArray>(args[9]);

      CHECK(direction == "in" || direction == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      const EdgeDir dir = (direction == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      auto subgraph = std::make_shared<HeteroSubgraph>();

      auto sampled = sampling::SampleLabors(
          graph.sptr(), seeds, fanouts, dir, prob, excluded,
          importance_sampling, random_seed, seed2_contribution, NIDs);
      *subgraph = sampled.first;

      // Pack the subgraph first, then every importance array, into one list.
      List<Value> packed;
      packed.push_back(Value(subgraph));
      for (auto& importance : sampled.second) {
        packed.push_back(Value(MakeValue(importance)));
      }

      *rv = packed;
    });

549
// Packed-function entry point for SampleNeighbors.
// Arguments: graph, nodes (list), fanouts (array), direction ("in" or "out"),
// prob_or_mask (list), exclude_edges (list), replace.
// Returns the sampled subgraph as a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const auto& seeds = ListValueToVector<IdArray>(args[1]);
      IdArray fanouts_nd = args[2];
      const auto& fanouts = fanouts_nd.ToVector<int64_t>();
      const std::string direction = args[3];
      const auto& prob_or_mask = ListValueToVector<NDArray>(args[4]);
      const auto& excluded = ListValueToVector<IdArray>(args[5]);
      const bool replace = args[6];

      CHECK(direction == "in" || direction == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      const EdgeDir dir = (direction == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      auto result = std::make_shared<HeteroSubgraph>();
      *result = sampling::SampleNeighbors(
          graph.sptr(), seeds, fanouts, dir, prob_or_mask, excluded, replace);

      *rv = HeteroSubgraphRef(result);
    });
570
571

// Packed-function entry point for SampleNeighborsTopk.
// Arguments: graph, nodes (list), k (array), direction ("in" or "out"),
// weight (list), ascending.
// Returns the selected subgraph as a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      const auto& nodes = ListValueToVector<IdArray>(args[1]);
      IdArray k_array = args[2];
      const auto& k = k_array.ToVector<int64_t>();
      const std::string dir_str = args[3];
      const auto& weight = ListValueToVector<FloatArray>(args[4]);
      const bool ascending = args[5];

      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
      *subg = sampling::SampleNeighborsTopk(
          hg.sptr(), nodes, k, dir, weight, ascending);

      // The object is a HeteroSubgraph, so wrap it in the matching reference
      // type, as the other sampler entry points do.
      *rv = HeteroSubgraphRef(subg);
    });
591

592
// Packed-function entry point for SampleNeighborsBiased.
// Arguments: graph, nodes, fanout, bias, tag_offset, direction ("in" or
// "out"), replace.
// Returns the sampled subgraph as a HeteroSubgraphRef.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      const IdArray nodes = args[1];
      const int64_t fanout = args[2];
      const NDArray bias = args[3];
      const NDArray tag_offset = args[4];
      const std::string dir_str = args[5];
      const bool replace = args[6];

      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
      *subg = sampling::SampleNeighborsBiased(
          hg.sptr(), nodes, fanout, bias, tag_offset, dir, replace);

      // The object is a HeteroSubgraph, so wrap it in the matching reference
      // type, as the other sampler entry points do.
      *rv = HeteroSubgraphRef(subg);
    });
612

613
614
}  // namespace sampling
}  // namespace dgl