neighbor.cc 18.2 KB
Newer Older
1
/**
 *  Copyright (c) 2020-2021 by Contributors
 * @file graph/sampling/neighbor.cc
 * @brief Definition of neighborhood-based sampler APIs.
 */

#include <dgl/array.h>
#include <dgl/aten/macro.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/sampling/neighbor.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "../../../c_api_common.h"
#include "../../unit_graph.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {
namespace sampling {

22
/**
 * @brief Remove a set of edges from a sampled subgraph.
 *
 * For every edge type, each edge of `sg` whose entry in
 * `sg.induced_edges[etype]` also occurs in `exclude_edges[etype]` is dropped.
 * The surviving edges keep their relative order.
 *
 * @param sg Subgraph to filter. Its induced edge arrays are read only; the
 *        filtering works on private copies.
 * @param exclude_edges One array of edge IDs per edge type, compared against
 *        the values stored in `sg.induced_edges`. An empty array leaves that
 *        edge type untouched. The arrays are read only; sorting is done on a
 *        private copy.
 * @return The edge subgraph of `sg.graph` made of the surviving edges, with
 *         `induced_edges` set to the surviving entries of `sg.induced_edges`.
 *
 * @note The arrays are accessed through raw pointers with std::sort and
 *       std::binary_search — presumably they must reside in host memory;
 *       confirm against callers.
 */
HeteroSubgraph ExcludeCertainEdges(
    const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges) {
  HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr();
  const auto num_etypes = hg_view->NumEdgeTypes();
  // exclude_edges is indexed by edge type below; reject a short list up front
  // instead of reading past its end.
  CHECK_EQ(exclude_edges.size(), num_etypes)
      << "Number of exclude edge tensors must match the number of edge types.";

  std::vector<IdArray> remain_induced_edges(num_etypes);
  std::vector<IdArray> remain_edges(num_etypes);

  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const IdArray& induced = sg.induced_edges[etype];
    const int64_t num_edges = induced->shape[0];
    // Positions 0..num_edges-1 of the edges inside the subgraph.
    IdArray edge_ids = Range(0, num_edges, induced->dtype.bits, induced->ctx);
    if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {
      // Nothing to exclude, or nothing to filter: keep everything as is.
      remain_edges[etype] = edge_ids;
      remain_induced_edges[etype] = induced;
      continue;
    }
    ATEN_ID_TYPE_SWITCH(hg_view->DataType(), IdType, {
      // Work on private copies: both inputs are const and share their
      // buffers with the caller, so sorting or compacting them in place would
      // silently modify the caller's data.
      IdArray sorted_exclude = exclude_edges[etype].Clone();
      IdArray kept_induced = induced.Clone();

      IdType* exclude_begin = sorted_exclude.Ptr<IdType>();
      IdType* exclude_end = exclude_begin + sorted_exclude->shape[0];
      // Sort once so that every membership test is a binary search.
      std::sort(exclude_begin, exclude_end);

      IdType* idx_data = edge_ids.Ptr<IdType>();
      IdType* induced_data = kept_induced.Ptr<IdType>();
      // Stable in-place compaction: slots [0, num_kept) hold the survivors.
      int64_t num_kept = 0;
      for (int64_t i = 0; i < num_edges; ++i) {
        if (!std::binary_search(exclude_begin, exclude_end, induced_data[i])) {
          induced_data[num_kept] = induced_data[i];
          idx_data[num_kept] = idx_data[i];
          ++num_kept;
        }
      }
      remain_edges[etype] = aten::IndexSelect(edge_ids, 0, num_kept);
      remain_induced_edges[etype] =
          aten::IndexSelect(kept_induced, 0, num_kept);
    });
  }
  HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true);
  // EdgeSubgraph was given positions within `sg`; report the values carried
  // by `sg.induced_edges` for the surviving edges instead.
  subg.induced_edges = std::move(remain_induced_edges);
  return subg;
}

65
/**
 * @brief Sample neighbors of the given seed nodes, one edge type at a time.
 *
 * For each edge type the seed nodes are taken from the source node type when
 * sampling out-edges and from the destination node type when sampling
 * in-edges. CSR is requested for out-edge sampling and CSC for in-edge
 * sampling; the format actually returned by SelectFormat decides which
 * row-wise sampling routine runs.
 *
 * @param hg Graph to sample from.
 * @param nodes One seed node ID array per node type.
 * @param fanouts One fanout per edge type; a fanout of 0 (or an empty seed
 *        array) produces an empty relation for that edge type.
 * @param dir Whether to sample in-edges or out-edges of the seeds.
 * @param prob_or_mask One array per edge type, forwarded to the row-wise
 *        sampling routine.
 * @param exclude_edges Either empty (no filtering) or one edge ID array per
 *        edge type; when non-empty the sampled result is passed through
 *        ExcludeCertainEdges.
 * @param replace Forwarded to the row-wise sampling routine.
 * @return A subgraph with the metagraph and per-type node counts of `hg`.
 *         `induced_edges[etype]` is the data array of the sampled COO matrix;
 *         `induced_vertices` is sized per node type but left unfilled.
 */
HeteroSubgraph SampleNeighbors(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<NDArray>& prob_or_mask,
    const std::vector<IdArray>& exclude_edges, bool replace) {
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
      << "Number of fanout values must match the number of edge types.";
  CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
      << "Number of probability tensors must match the number of edge types.";
  // ExcludeCertainEdges indexes exclude_edges by edge type, so a non-empty
  // list must cover every edge type.
  CHECK(exclude_edges.empty() || exclude_edges.size() == hg->NumEdgeTypes())
      << "Number of exclude edge tensors must match the number of edge types.";

  DGLContext ctx = aten::GetContextOf(nodes);

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    const IdArray nodes_ntype =
        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];

    if (num_nodes == 0 || fanouts[etype] == 0) {
      // Nothing to sample for this etype, create a placeholder relation graph
      subrels[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), ctx);
      induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
    } else {
      COOMatrix sampled_coo;
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      switch (avail_fmt) {
        case SparseFormat::kCOO:
          if (dir == EdgeDir::kIn) {
            // Rows must be the seed side: transpose, sample, transpose back.
            sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(
                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
                fanouts[etype], prob_or_mask[etype], replace));
          } else {
            sampled_coo = aten::COORowWiseSampling(
                hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype],
                prob_or_mask[etype], replace);
          }
          break;
        case SparseFormat::kCSR:
          // Failing here means in-edges were requested on a CSR matrix.
          CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
          sampled_coo = aten::CSRRowWiseSampling(
              hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype],
              prob_or_mask[etype], replace);
          break;
        case SparseFormat::kCSC:
          // Failing here means out-edges were requested on a CSC matrix.
          CHECK(dir == EdgeDir::kIn)
              << "Cannot sample out edges on CSC matrix.";
          sampled_coo = aten::CSRRowWiseSampling(
              hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype],
              prob_or_mask[etype], replace);
          // The CSC matrix was sampled row-wise; flip the result back.
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }

      subrels[etype] = UnitGraph::CreateFromCOO(
          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
      induced_edges[etype] = sampled_coo.data;
    }
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);

  if (!exclude_edges.empty()) {
    return ExcludeCertainEdges(ret, exclude_edges);
  }
  return ret;
}

150
/**
 * @brief Per-edge-type neighbor sampling on a graph stored with a single node
 *        type and a single edge type.
 *
 * @param hg Graph to sample from; must have exactly one node type and one
 *        edge type.
 * @param nodes Seed node IDs.
 * @param eid2etype_offset Forwarded to the per-etype sampling routine.
 * @param fanouts Fanout values, forwarded to the per-etype sampling routine;
 *        must not be empty. If every value is 0 an empty graph is returned.
 * @param dir Whether to sample in-edges or out-edges of the seeds.
 * @param prob Forwarded to the per-etype sampling routine.
 * @param replace Forwarded to the per-etype sampling routine.
 * @param rowwise_etype_sorted Forwarded to the CSR/CSC sampling routine; the
 *        COO path does not take it.
 * @return A subgraph with the metagraph and node counts of `hg`.
 *         `induced_edges[0]` is the data array of the sampled COO matrix;
 *         `induced_vertices` is sized per node type but left unfilled.
 *
 * @note NOTE(review): unlike SampleNeighbors, the empty placeholder uses
 *       `hg->Context()` and a default-constructed `aten::NullArray()` —
 *       confirm this difference is intended.
 */
HeteroSubgraph SampleNeighborsEType(
    const HeteroGraphPtr hg, const IdArray nodes,
    const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& fanouts, EdgeDir dir,
    const std::vector<FloatArray>& prob, bool replace,
    bool rowwise_etype_sorted) {
  CHECK_EQ(1, hg->NumVertexTypes())
      << "SampleNeighborsEType only work with homogeneous graph";
  CHECK_EQ(1, hg->NumEdgeTypes())
      << "SampleNeighborsEType only work with homogeneous graph";
  // The fanout list is inspected below; an empty list has no meaning here.
  CHECK(!fanouts.empty()) << "At least one fanout value is required.";

  std::vector<HeteroGraphPtr> subrels(1);
  std::vector<IdArray> induced_edges(1);
  const int64_t num_nodes = nodes->shape[0];
  const dgl_type_t etype = 0;
  const dgl_type_t src_vtype = 0;
  const dgl_type_t dst_vtype = 0;

  // Sampling can be skipped only when every fanout is zero.
  bool all_fanouts_zero = true;
  for (const int64_t fanout : fanouts) {
    if (fanout != 0) {
      all_fanouts_zero = false;
      break;
    }
  }

  if (num_nodes == 0 || all_fanouts_zero) {
    // Nothing to sample, create a placeholder relation graph
    subrels[etype] = UnitGraph::Empty(
        1, hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
        hg->DataType(), hg->Context());
    induced_edges[etype] = aten::NullArray();
  } else {
    COOMatrix sampled_coo;
    // sample from graph
    // the edge type is stored in etypes
    auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    auto avail_fmt = hg->SelectFormat(etype, req_fmt);
    switch (avail_fmt) {
      case SparseFormat::kCOO:
        if (dir == EdgeDir::kIn) {
          // Rows must be the seed side: transpose, sample, transpose back.
          sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
              aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes,
              eid2etype_offset, fanouts, prob, replace));
        } else {
          sampled_coo = aten::COORowWisePerEtypeSampling(
              hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
              replace);
        }
        break;
      case SparseFormat::kCSR:
        // Failing here means in-edges were requested on a CSR matrix.
        CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
        sampled_coo = aten::CSRRowWisePerEtypeSampling(
            hg->GetCSRMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
            replace, rowwise_etype_sorted);
        break;
      case SparseFormat::kCSC:
        // Failing here means out-edges were requested on a CSC matrix.
        CHECK(dir == EdgeDir::kIn) << "Cannot sample out edges on CSC matrix.";
        sampled_coo = aten::CSRRowWisePerEtypeSampling(
            hg->GetCSCMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
            replace, rowwise_etype_sorted);
        // The CSC matrix was sampled row-wise; flip the result back.
        sampled_coo = aten::COOTranspose(sampled_coo);
        break;
      default:
        LOG(FATAL) << "Unsupported sparse format.";
    }

    subrels[etype] = UnitGraph::CreateFromCOO(
        1, sampled_coo.num_rows, sampled_coo.num_cols, sampled_coo.row,
        sampled_coo.col);
    induced_edges[etype] = sampled_coo.data;
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}

231
/**
 * @brief Select the top-k neighbors of the given seed nodes by edge weight,
 *        one edge type at a time.
 *
 * @param hg Graph to select from.
 * @param nodes One seed node ID array per node type.
 * @param k One k value per edge type; a k of 0 (or an empty seed array)
 *        produces an empty relation for that edge type.
 * @param dir Whether to look at in-edges or out-edges of the seeds.
 * @param weight One weight array per edge type, forwarded to the row-wise
 *        top-k routine.
 * @param ascending Forwarded to the row-wise top-k routine.
 * @return A subgraph with the metagraph and per-type node counts of `hg`.
 *         `induced_edges[etype]` is the data array of the selected COO
 *         matrix; `induced_vertices` is sized per node type but left
 *         unfilled.
 *
 * @note NOTE(review): unlike SampleNeighbors, the empty placeholder uses
 *       `hg->Context()` and a default-constructed `aten::NullArray()` —
 *       confirm this difference is intended.
 */
HeteroSubgraph SampleNeighborsTopk(
    const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& k, EdgeDir dir,
    const std::vector<FloatArray>& weight, bool ascending) {
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
      << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(k.size(), hg->NumEdgeTypes())
      << "Number of k values must match the number of edge types.";
  CHECK_EQ(weight.size(), hg->NumEdgeTypes())
      << "Number of weight tensors must match the number of edge types.";

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    const IdArray nodes_ntype =
        nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];
    if (num_nodes == 0 || k[etype] == 0) {
      // Nothing to sample for this etype, create a placeholder relation graph
      subrels[etype] = UnitGraph::Empty(
          hg->GetRelationGraph(etype)->NumVertexTypes(),
          hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
          hg->DataType(), hg->Context());
      induced_edges[etype] = aten::NullArray();
    } else {
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      COOMatrix sampled_coo;
      switch (avail_fmt) {
        case SparseFormat::kCOO:
          if (dir == EdgeDir::kIn) {
            // Rows must be the seed side: transpose, select, transpose back.
            sampled_coo = aten::COOTranspose(aten::COORowWiseTopk(
                aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
                k[etype], weight[etype], ascending));
          } else {
            sampled_coo = aten::COORowWiseTopk(
                hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype],
                ascending);
          }
          break;
        case SparseFormat::kCSR:
          // Failing here means in-edges were requested on a CSR matrix.
          CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
          sampled_coo = aten::CSRRowWiseTopk(
              hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype],
              ascending);
          break;
        case SparseFormat::kCSC:
          // Failing here means out-edges were requested on a CSC matrix.
          CHECK(dir == EdgeDir::kIn)
              << "Cannot sample out edges on CSC matrix.";
          sampled_coo = aten::CSRRowWiseTopk(
              hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype],
              ascending);
          // The CSC matrix was processed row-wise; flip the result back.
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }
      subrels[etype] = UnitGraph::CreateFromCOO(
          hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
          sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
      induced_edges[etype] = sampled_coo.data;
    }
  }

  HeteroSubgraph ret;
  ret.graph =
      CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}

308
/**
 * @brief Biased neighbor sampling on a graph with a single edge type.
 *
 * Out-edge sampling needs an already created CSR matrix and in-edge sampling
 * an already created CSC matrix; the function fails otherwise.
 *
 * @param hg Graph to sample from; must have exactly one edge type.
 * @param nodes Seed node IDs.
 * @param fanout Fanout; 0 (or an empty seed array) yields an empty graph.
 * @param bias Bias array; `tag_offset` must have `bias->shape[0] + 1` columns.
 * @param tag_offset 2-D array with one row per node of the seed node type.
 * @param dir Whether to sample in-edges or out-edges of the seeds.
 * @param replace Forwarded to the biased row-wise sampling routine.
 * @return A subgraph with the metagraph and per-type node counts of `hg`.
 *         `induced_edges[0]` is the data array of the sampled COO matrix;
 *         `induced_vertices` is sized per node type but left unfilled.
 */
HeteroSubgraph SampleNeighborsBiased(
    const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanout,
    const NDArray& bias, const NDArray& tag_offset, const EdgeDir dir,
    const bool replace) {
  CHECK_EQ(hg->NumEdgeTypes(), 1)
      << "Only homogeneous or bipartite graphs are supported";

  // Endpoints of the only relation, and the node type the seeds belong to.
  const auto endpoints = hg->meta_graph()->FindEdge(0);
  const dgl_type_t src_vtype = endpoints.first;
  const dgl_type_t dst_vtype = endpoints.second;
  const dgl_type_t nodes_ntype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype;

  // sanity check
  CHECK_EQ(tag_offset->ndim, 2)
      << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[0], hg->NumVertices(nodes_ntype))
      << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
  CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1)
      << "The sizes of tag_offset and bias are inconsistent";

  const dgl_type_t etype = 0;
  const int64_t seed_count = nodes->shape[0];
  const bool nothing_to_sample = (seed_count == 0) || (fanout == 0);

  HeteroGraphPtr relation;
  IdArray picked_edges;
  if (!nothing_to_sample) {
    // sample from one relation graph
    const auto wanted_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
    const auto created_fmt = hg->GetCreatedFormats();
    COOMatrix picked;

    if (wanted_fmt == CSR_CODE) {
      CHECK(created_fmt & CSR_CODE) << "A sorted CSR Matrix is required.";
      picked = aten::CSRRowWiseSamplingBiased(
          hg->GetCSRMatrix(etype), nodes, fanout, tag_offset, bias, replace);
    } else if (wanted_fmt == CSC_CODE) {
      CHECK(created_fmt & CSC_CODE) << "A sorted CSC Matrix is required.";
      // Sample row-wise on the CSC matrix, then flip the result back.
      picked = aten::COOTranspose(aten::CSRRowWiseSamplingBiased(
          hg->GetCSCMatrix(etype), nodes, fanout, tag_offset, bias, replace));
    } else {
      LOG(FATAL) << "Unsupported sparse format.";
    }

    relation = UnitGraph::CreateFromCOO(
        hg->GetRelationGraph(etype)->NumVertexTypes(), picked.num_rows,
        picked.num_cols, picked.row, picked.col);
    picked_edges = picked.data;
  } else {
    // Nothing to sample for this etype, create a placeholder relation graph
    relation = UnitGraph::Empty(
        hg->GetRelationGraph(etype)->NumVertexTypes(),
        hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype), hg->DataType(),
        hg->Context());
    picked_edges = aten::NullArray();
  }

  HeteroSubgraph result;
  result.graph =
      CreateHeteroGraph(hg->meta_graph(), {relation}, hg->NumVerticesPerType());
  result.induced_vertices.resize(hg->NumVertexTypes());
  result.induced_edges = {picked_edges};
  return result;
}

373
/**
 * @brief Packed-function entry point for SampleNeighborsEType.
 *
 * Arguments, in order: graph, seed nodes, eid2etype offsets (list), fanouts
 * (int64 array), direction string ("in" or "out"), probabilities (list),
 * replace flag, rowwise_etype_sorted flag. Returns a HeteroSubgraphRef.
 */
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Unpack the positional arguments.
      HeteroGraphRef hg = args[0];
      IdArray nodes = args[1];
      const auto eid2etype_offset = ListValueToVector<int64_t>(args[2]);
      IdArray fanout = args[3];
      const std::string dir_str = args[4];
      const auto prob = ListValueToVector<FloatArray>(args[5]);
      const bool replace = args[6];
      const bool rowwise_etype_sorted = args[7];

      // Validate and translate the direction and the fanout array.
      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = EdgeDir::kOut;
      if (dir_str == "in") {
        dir = EdgeDir::kIn;
      }
      CHECK_INT64(fanout, "fanout");
      const std::vector<int64_t> fanout_values = fanout.ToVector<int64_t>();

      auto result = std::make_shared<HeteroSubgraph>();
      *result = sampling::SampleNeighborsEType(
          hg.sptr(), nodes, eid2etype_offset, fanout_values, dir, prob,
          replace, rowwise_etype_sorted);
      *rv = HeteroSubgraphRef(result);
    });
397

398
/**
 * @brief Packed-function entry point for SampleNeighbors.
 *
 * Arguments, in order: graph, seed nodes (list, one array per node type),
 * fanouts (array), direction string ("in" or "out"), prob_or_mask (list),
 * exclude_edges (list), replace flag. Returns a HeteroSubgraphRef.
 */
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // Unpack the positional arguments.
      HeteroGraphRef hg = args[0];
      const auto nodes = ListValueToVector<IdArray>(args[1]);
      IdArray fanouts_array = args[2];
      const auto fanouts = fanouts_array.ToVector<int64_t>();
      const std::string dir_str = args[3];
      const auto prob_or_mask = ListValueToVector<NDArray>(args[4]);
      const auto exclude_edges = ListValueToVector<IdArray>(args[5]);
      const bool replace = args[6];

      // Validate and translate the direction string.
      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = EdgeDir::kOut;
      if (dir_str == "in") {
        dir = EdgeDir::kIn;
      }

      auto result = std::make_shared<HeteroSubgraph>();
      *result = sampling::SampleNeighbors(
          hg.sptr(), nodes, fanouts, dir, prob_or_mask, exclude_edges, replace);
      *rv = HeteroSubgraphRef(result);
    });
419
420

/**
 * @brief Packed-function entry point for SampleNeighborsTopk.
 *
 * Arguments, in order: graph, seed nodes (list, one array per node type),
 * k values (array), direction string ("in" or "out"), weights (list),
 * ascending flag. Returns a HeteroSubgraphRef.
 */
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      const auto& nodes = ListValueToVector<IdArray>(args[1]);
      IdArray k_array = args[2];
      const auto& k = k_array.ToVector<int64_t>();
      const std::string dir_str = args[3];
      const auto& weight = ListValueToVector<FloatArray>(args[4]);
      const bool ascending = args[5];

      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
      *subg = sampling::SampleNeighborsTopk(
          hg.sptr(), nodes, k, dir, weight, ascending);

      // The payload is a HeteroSubgraph, so wrap it in the matching reference
      // type, as the other sampler entry points in this file do.
      *rv = HeteroSubgraphRef(subg);
    });
440

441
/**
 * @brief Packed-function entry point for SampleNeighborsBiased.
 *
 * Arguments, in order: graph, seed nodes, fanout, bias array, tag_offset
 * array, direction string ("in" or "out"), replace flag. Returns a
 * HeteroSubgraphRef.
 */
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      const IdArray nodes = args[1];
      const int64_t fanout = args[2];
      const NDArray bias = args[3];
      const NDArray tag_offset = args[4];
      const std::string dir_str = args[5];
      const bool replace = args[6];

      CHECK(dir_str == "in" || dir_str == "out")
          << "Invalid edge direction. Must be \"in\" or \"out\".";
      EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;

      std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
      *subg = sampling::SampleNeighborsBiased(
          hg.sptr(), nodes, fanout, bias, tag_offset, dir, replace);

      // The payload is a HeteroSubgraph, so wrap it in the matching reference
      // type, as the other sampler entry points in this file do.
      *rv = HeteroSubgraphRef(subg);
    });
461

462
463
}  // namespace sampling
}  // namespace dgl