graph_apis.cc 18.8 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/graph_apis.cc
 * \brief DGL graph index APIs
 */
Minjie Wang's avatar
Minjie Wang committed
6
#include <dgl/graph.h>
7
#include <dgl/immutable_graph.h>
Minjie Wang's avatar
Minjie Wang committed
8
#include <dgl/graph_op.h>
Lingfan Yu's avatar
Lingfan Yu committed
9
#include "../c_api_common.h"
Minjie Wang's avatar
Minjie Wang committed
10

11
12
13
14
15
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
Minjie Wang's avatar
Minjie Wang committed
16
17

namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
18

Minjie Wang's avatar
Minjie Wang committed
19
namespace {
Minjie Wang's avatar
Minjie Wang committed
20
// Convert EdgeArray structure to PackedFunc.
21
22
template<class EdgeArray>
PackedFunc ConvertEdgeArrayToPackedFunc(const EdgeArray& ea) {
23
  auto body = [ea] (DGLArgs args, DGLRetValue* rv) {
Minjie Wang's avatar
Minjie Wang committed
24
      const int which = args[0];
25
      if (which == 0) {
Minjie Wang's avatar
Minjie Wang committed
26
27
28
29
30
31
32
33
34
35
36
37
        *rv = std::move(ea.src);
      } else if (which == 1) {
        *rv = std::move(ea.dst);
      } else if (which == 2) {
        *rv = std::move(ea.id);
      } else {
        LOG(FATAL) << "invalid choice";
      }
    };
  return PackedFunc(body);
}

38
39
40
41
42
43
44
45
46
47
48
49
50
/*!
 * \brief Wrap a list of arrays in a selector PackedFunc.
 *
 * The returned function takes one integer argument `which` and yields
 * `ea[which]` from a private copy of \p ea. An index outside
 * [0, ea.size()) triggers LOG(FATAL) with the message "invalid choice".
 * In this file it is applied to the result of GraphInterface::GetAdj.
 *
 * \param ea The arrays to expose; the vector is copied into the closure.
 * \return A PackedFunc handing out the requested array.
 */
PackedFunc ConvertAdjToPackedFunc(const std::vector<IdArray>& ea) {
  auto body = [ea] (DGLArgs args, DGLRetValue* rv) {
      const int which = args[0];
      // Reject negative selectors explicitly instead of relying on the
      // wrap-around of a signed-to-unsigned conversion.
      if (which >= 0 && static_cast<size_t>(which) < ea.size()) {
        // The captured vector is const here, so this is (and always was) a copy.
        *rv = ea[which];
      } else {
        LOG(FATAL) << "invalid choice";
      }
    };
  return PackedFunc(body);
}

Minjie Wang's avatar
Minjie Wang committed
51
52
/*!
 * \brief Wrap a Subgraph in a selector PackedFunc.
 *
 * The returned function takes one integer argument:
 *   - 0: calls Reset() on the subgraph's graph and returns the resulting
 *        GraphInterface pointer as a GraphHandle;
 *   - 1: returns `induced_vertices`;
 *   - 2: returns `induced_edges`.
 * Any other selector triggers LOG(FATAL) with the message "invalid choice".
 *
 * \param sg The subgraph to expose; it is copied into the closure.
 * \return A PackedFunc handing out the requested component.
 */
PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
  auto body = [sg] (DGLArgs args, DGLRetValue* rv) {
      const int which = args[0];
      switch (which) {
        case 0: {
          // NOTE(review): Reset() runs on every selector-0 call; presumably
          // callers fetch the graph exactly once -- confirm.
          GraphInterface* gptr = sg.graph->Reset();
          GraphHandle ghandle = gptr;
          *rv = ghandle;
          break;
        }
        case 1:
          // The captured subgraph is const inside the closure, so the arrays
          // are copied; std::move would not change that.
          *rv = sg.induced_vertices;
          break;
        case 2:
          *rv = sg.induced_edges;
          break;
        default:
          LOG(FATAL) << "invalid choice";
      }
    };
  return PackedFunc(body);
}
Minjie Wang's avatar
Minjie Wang committed
69

70
71
72
/*!
 * \brief Wrap a list of sampled subgraphs in a selector PackedFunc.
 *
 * With n = sg.size(), the returned function takes one index `which` and
 * interprets it as five consecutive bands of width n:
 *   - [0,  n):  graph of subgraph `which` (via Reset(), as a GraphHandle);
 *   - [n, 2n):  `induced_vertices` of subgraph `which - n`;
 *   - [2n, 3n): `induced_edges`    of subgraph `which - 2n`;
 *   - [3n, 4n): `layer_ids`        of subgraph `which - 3n`;
 *   - [4n, 5n): `sample_prob`      of subgraph `which - 4n`.
 * Any other index (including every index when n == 0) triggers LOG(FATAL)
 * with the message "invalid choice".
 *
 * \param sg The sampled subgraphs to expose; the vector is copied into the closure.
 * \return A PackedFunc handing out the requested component.
 */
PackedFunc ConvertSubgraphToPackedFunc(const std::vector<SampledSubgraph>& sg) {
  auto body = [sg] (DGLArgs args, DGLRetValue* rv) {
      const size_t which = args[0];
      const size_t count = sg.size();
      // Each branch is only reached when the previous upper bound failed, so
      // the lower bound of every band is already implied.
      if (which < count) {
        GraphInterface* gptr = sg[which].graph->Reset();
        GraphHandle ghandle = gptr;
        *rv = ghandle;
      } else if (which < count * 2) {
        // The captured vector is const here; the arrays are copied.
        *rv = sg[which - count].induced_vertices;
      } else if (which < count * 3) {
        *rv = sg[which - count * 2].induced_edges;
      } else if (which < count * 4) {
        *rv = sg[which - count * 3].layer_ids;
      } else if (which < count * 5) {
        *rv = sg[which - count * 4].sample_prob;
      } else {
        LOG(FATAL) << "invalid choice";
      }
    };
  // TODO(minjie): figure out a better way of returning a complex results.
  return PackedFunc(body);
}

Minjie Wang's avatar
Minjie Wang committed
94
}  // namespace
Minjie Wang's avatar
Minjie Wang committed
95

96
97
98
///////////////////////////// Graph API ///////////////////////////////////

// Create an empty mutable Graph.
//   args[0]: multigraph flag (bool).
// Returns the new graph as a GraphHandle; _CAPI_DGLGraphFree deletes it.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    const bool multigraph = static_cast<bool>(args[0]);
    // Every consumer in this file casts the handle back to GraphInterface*,
    // so upcast before erasing the type to void* to keep that round trip
    // well-defined.
    GraphInterface* gptr = new Graph(multigraph);
    GraphHandle ghandle = gptr;
    *rv = ghandle;
  });
Minjie Wang's avatar
Minjie Wang committed
104

105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Create a graph from edge lists.
//   args[0], args[1], args[2]: source ids, destination ids and edge ids
//                              (tensors converted through DLPack);
//   args[3]: multigraph flag (bool);
//   args[4]: number of nodes (int64_t);
//   args[5]: readonly flag (bool) -- selects ImmutableGraph over Graph.
// Returns the new graph as a GraphHandle; _CAPI_DGLGraphFree deletes it.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    const IdArray src_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0]));
    const IdArray dst_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray edge_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    const bool multigraph = static_cast<bool>(args[3]);
    const int64_t num_nodes = static_cast<int64_t>(args[4]);
    const bool readonly = static_cast<bool>(args[5]);
    // Hold the result as GraphInterface* (never uninitialized) so the handle
    // is erased from the same type that consumers cast it back to.
    GraphInterface* gptr = nullptr;
    if (readonly) {
      gptr = new ImmutableGraph(src_ids, dst_ids, edge_ids, num_nodes, multigraph);
    } else {
      gptr = new Graph(src_ids, dst_ids, edge_ids, num_nodes, multigraph);
    }
    GraphHandle ghandle = gptr;
    *rv = ghandle;
  });

121
122
// Destroy the graph behind a handle.
//   args[0]: graph handle.
// The handle must not be used after this call.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    GraphInterface* graph = static_cast<GraphInterface*>(handle);
    delete graph;
  });
Minjie Wang's avatar
Minjie Wang committed
127

128
129
// Forward to GraphInterface::AddVertices.
//   args[0]: graph handle;
//   args[1]: number of vertices to add (uint64_t).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    uint64_t count = args[1];
    GraphInterface* graph = static_cast<GraphInterface*>(handle);
    graph->AddVertices(count);
  });
Minjie Wang's avatar
Minjie Wang committed
135

136
137
// Forward to GraphInterface::AddEdge.
//   args[0]: graph handle;
//   args[1]: source vertex id;
//   args[2]: destination vertex id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t from = args[1];
    const dgl_id_t to = args[2];
    GraphInterface* graph = static_cast<GraphInterface*>(handle);
    graph->AddEdge(from, to);
  });
Minjie Wang's avatar
Minjie Wang committed
144

145
146
// Forward to GraphInterface::AddEdges.
//   args[0]: graph handle;
//   args[1]: source vertex ids (tensor converted through DLPack);
//   args[2]: destination vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    GraphInterface* graph = static_cast<GraphInterface*>(handle);
    const IdArray sources = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray targets = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    graph->AddEdges(sources, targets);
  });

154
155
// Forward to GraphInterface::Clear.
//   args[0]: graph handle.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    GraphInterface* graph = static_cast<GraphInterface*>(handle);
    graph->Clear();
  });
Minjie Wang's avatar
Minjie Wang committed
160

161
162
// Return GraphInterface::IsMultigraph() for the given graph.
//   args[0]: graph handle.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    GraphHandle handle = args[0];
    // The query goes through a const pointer; the graph is not modified here.
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    *rv = graph->IsMultigraph();
  });

169
170
171
172
173
174
175
176
// Return GraphInterface::IsReadonly() for the given graph.
//   args[0]: graph handle.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    GraphHandle handle = args[0];
    // The query goes through a const pointer; the graph is not modified here.
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    *rv = graph->IsReadonly();
  });

177
178
// Return GraphInterface::NumVertices() as an int64_t.
//   args[0]: graph handle.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const int64_t vertex_count = static_cast<int64_t>(graph->NumVertices());
    *rv = vertex_count;
  });
Minjie Wang's avatar
Minjie Wang committed
183

184
185
// Return GraphInterface::NumEdges() as an int64_t.
//   args[0]: graph handle.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const int64_t edge_count = static_cast<int64_t>(graph->NumEdges());
    *rv = edge_count;
  });

191
192
// Return GraphInterface::HasVertex(vid).
//   args[0]: graph handle;
//   args[1]: vertex id to look up.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t vertex = args[1];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    *rv = graph->HasVertex(vertex);
  });

199
200
// Return GraphInterface::HasVertices(vids).
//   args[0]: graph handle;
//   args[1]: vertex ids to look up (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray vertices = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = graph->HasVertices(vertices);
  });

207
208
// Return GraphOp::MapParentIdToSubgraphId(parent_vids, query).
//   args[0]: parent vertex ids (tensor converted through DLPack);
//   args[1]: query ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    const IdArray parent_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0]));
    const IdArray lookup_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = GraphOp::MapParentIdToSubgraphId(parent_ids, lookup_ids);
  });

214
215
// Return GraphInterface::HasEdgeBetween(src, dst).
//   args[0]: graph handle;
//   args[1]: source vertex id;
//   args[2]: destination vertex id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t from = args[1];
    const dgl_id_t to = args[2];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    *rv = graph->HasEdgeBetween(from, to);
  });

223
224
// Return GraphInterface::HasEdgesBetween(src, dst).
//   args[0]: graph handle;
//   args[1]: source vertex ids (tensor converted through DLPack);
//   args[2]: destination vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray sources = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray targets = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    *rv = graph->HasEdgesBetween(sources, targets);
  });

232
233
// Return GraphInterface::Predecessors(vid, radius).
//   args[0]: graph handle;
//   args[1]: vertex id;
//   args[2]: radius (uint64_t).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t vertex = args[1];
    const uint64_t hops = args[2];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    *rv = graph->Predecessors(vertex, hops);
  });

241
242
// Return GraphInterface::Successors(vid, radius).
//   args[0]: graph handle;
//   args[1]: vertex id;
//   args[2]: radius (uint64_t).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t vertex = args[1];
    const uint64_t hops = args[2];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    *rv = graph->Successors(vertex, hops);
  });

250
251
// Return GraphInterface::EdgeId(src, dst).
//   args[0]: graph handle;
//   args[1]: source vertex id;
//   args[2]: destination vertex id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t from = args[1];
    const dgl_id_t to = args[2];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    *rv = graph->EdgeId(from, to);
  });

259
260
// Return GraphInterface::EdgeIds(src, dst), wrapped by
// ConvertEdgeArrayToPackedFunc.
//   args[0]: graph handle;
//   args[1]: source vertex ids (tensor converted through DLPack);
//   args[2]: destination vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray sources = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray targets = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    const auto edges = graph->EdgeIds(sources, targets);
    *rv = ConvertEdgeArrayToPackedFunc(edges);
  });

268
269
// Return GraphInterface::FindEdges(eids), wrapped by
// ConvertEdgeArrayToPackedFunc.
//   args[0]: graph handle;
//   args[1]: edge ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray edge_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const auto edges = graph->FindEdges(edge_ids);
    *rv = ConvertEdgeArrayToPackedFunc(edges);
  });

276
277
// Return GraphInterface::InEdges(vid) for a single vertex, wrapped by
// ConvertEdgeArrayToPackedFunc.
//   args[0]: graph handle;
//   args[1]: vertex id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t vertex = args[1];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const auto edges = graph->InEdges(vertex);
    *rv = ConvertEdgeArrayToPackedFunc(edges);
  });

284
285
// Return GraphInterface::InEdges(vids) for a batch of vertices, wrapped by
// ConvertEdgeArrayToPackedFunc.
//   args[0]: graph handle;
//   args[1]: vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray vertices = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const auto edges = graph->InEdges(vertices);
    *rv = ConvertEdgeArrayToPackedFunc(edges);
  });

292
293
// Return GraphInterface::OutEdges(vid) for a single vertex, wrapped by
// ConvertEdgeArrayToPackedFunc.
//   args[0]: graph handle;
//   args[1]: vertex id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t vertex = args[1];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const auto edges = graph->OutEdges(vertex);
    *rv = ConvertEdgeArrayToPackedFunc(edges);
  });

300
301
// Return GraphInterface::OutEdges(vids) for a batch of vertices, wrapped by
// ConvertEdgeArrayToPackedFunc.
//   args[0]: graph handle;
//   args[1]: vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray vertices = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const auto edges = graph->OutEdges(vertices);
    *rv = ConvertEdgeArrayToPackedFunc(edges);
  });

308
309
// Return GraphInterface::Edges(order), wrapped by
// ConvertEdgeArrayToPackedFunc.
//   args[0]: graph handle;
//   args[1]: ordering name (string) passed through unchanged.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    std::string ordering = args[1];
    const auto edges = graph->Edges(ordering);
    *rv = ConvertEdgeArrayToPackedFunc(edges);
  });

316
317
// Return GraphInterface::InDegree(vid) as an int64_t.
//   args[0]: graph handle;
//   args[1]: vertex id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t vertex = args[1];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const int64_t degree = static_cast<int64_t>(graph->InDegree(vertex));
    *rv = degree;
  });

324
325
// Return GraphInterface::InDegrees(vids).
//   args[0]: graph handle;
//   args[1]: vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray vertices = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = graph->InDegrees(vertices);
  });

332
333
// Return GraphInterface::OutDegree(vid) as an int64_t.
//   args[0]: graph handle;
//   args[1]: vertex id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const dgl_id_t vertex = args[1];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const int64_t degree = static_cast<int64_t>(graph->OutDegree(vertex));
    *rv = degree;
  });

340
341
// Return GraphInterface::OutDegrees(vids).
//   args[0]: graph handle;
//   args[1]: vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray vertices = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = graph->OutDegrees(vertices);
  });
Minjie Wang's avatar
Minjie Wang committed
347

348
349
// Return GraphInterface::VertexSubgraph(vids), wrapped by
// ConvertSubgraphToPackedFunc.
//   args[0]: graph handle;
//   args[1]: vertex ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray vertices = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const auto subgraph = graph->VertexSubgraph(vertices);
    *rv = ConvertSubgraphToPackedFunc(subgraph);
  });

356
357
// Return GraphInterface::EdgeSubgraph(eids), wrapped by
// ConvertSubgraphToPackedFunc.
//   args[0]: graph handle;
//   args[1]: edge ids (tensor converted through DLPack).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const GraphInterface* graph = static_cast<const GraphInterface*>(handle);
    const IdArray edge_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const auto subgraph = graph->EdgeSubgraph(edge_ids);
    *rv = ConvertSubgraphToPackedFunc(subgraph);
  });

364
365
// Build a new Graph with GraphOp::DisjointUnion over the given graphs.
//   args[0]: pointer to an array of graph handles;
//   args[1]: number of handles in that array (int).
// Every input must be a mutable Graph; anything else fails the CHECK below.
// Returns the new graph as a GraphHandle; _CAPI_DGLGraphFree deletes it.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    void* list = args[0];
    GraphHandle* inhandles = static_cast<GraphHandle*>(list);
    const int list_size = args[1];
    std::vector<const Graph*> graphs;
    if (list_size > 0) {
      graphs.reserve(static_cast<size_t>(list_size));
    }
    for (int i = 0; i < list_size; ++i) {
      const GraphInterface *ptr = static_cast<const GraphInterface *>(inhandles[i]);
      const Graph* gr = dynamic_cast<const Graph*>(ptr);
      CHECK(gr) << "_CAPI_DGLDisjointUnion isn't implemented in immutable graph";
      graphs.push_back(gr);
    }
    // Construct the result in a single new-expression: if DisjointUnion throws,
    // the allocation is released automatically, whereas allocating first and
    // assigning afterwards would leak it. The GraphInterface* type matches what
    // consumers cast the handle back to.
    GraphInterface* gptr = new Graph(GraphOp::DisjointUnion(std::move(graphs)));
    GraphHandle ghandle = gptr;
    *rv = ghandle;
  });

382
383
// Split a graph with GraphOp::DisjointPartitionByNum.
//   args[0]: graph handle (must be a mutable Graph);
//   args[1]: `num` argument forwarded to DisjointPartitionByNum (int64_t).
// Returns a 1-D int64 CPU array with one entry per resulting graph; each
// entry is the address of a newly allocated graph, encoded as an integer.
// NOTE(review): the receiver is presumably expected to treat each entry as a
// graph handle and free it -- confirm against the caller.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle ghandle = args[0];
    const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
    const Graph* gptr = dynamic_cast<const Graph*>(ptr);
    CHECK(gptr) << "_CAPI_DGLDisjointPartitionByNum isn't implemented in immutable graph";
    int64_t num = args[1];
    std::vector<Graph> parts = GraphOp::DisjointPartitionByNum(gptr, num);
    // return the pointer array as an integer array
    const int64_t len = parts.size();
    NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
    for (int64_t i = 0; i < len; ++i) {
      Graph* part = new Graph();
      *part = std::move(parts[i]);
      // Store the address as a GraphInterface*, the type consumers cast
      // handles back to. (A distinct name also avoids shadowing `ptr` above.)
      GraphInterface* part_iface = part;
      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(part_iface);
    }
    *rv = ptr_array;
  });

402
403
// Split a graph with GraphOp::DisjointPartitionBySizes.
//   args[0]: graph handle (must be a mutable Graph);
//   args[1]: `sizes` array forwarded to DisjointPartitionBySizes
//            (tensor converted through DLPack).
// Returns a 1-D int64 CPU array with one entry per resulting graph; each
// entry is the address of a newly allocated graph, encoded as an integer.
// NOTE(review): the receiver is presumably expected to treat each entry as a
// graph handle and free it -- confirm against the caller.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle ghandle = args[0];
    const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
    const Graph* gptr = dynamic_cast<const Graph*>(ptr);
    CHECK(gptr) << "_CAPI_DGLDisjointPartitionBySizes isn't implemented in immutable graph";
    const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    std::vector<Graph> parts = GraphOp::DisjointPartitionBySizes(gptr, sizes);
    // return the pointer array as an integer array
    const int64_t len = parts.size();
    NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
    for (int64_t i = 0; i < len; ++i) {
      Graph* part = new Graph();
      *part = std::move(parts[i]);
      // Store the address as a GraphInterface*, the type consumers cast
      // handles back to. (A distinct name also avoids shadowing `ptr` above.)
      GraphInterface* part_iface = part;
      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(part_iface);
    }
    *rv = ptr_array;
  });
GaiYu0's avatar
cpp lg  
GaiYu0 committed
421

422
423
// Build a new Graph with GraphOp::LineGraph.
//   args[0]: graph handle (must be a mutable Graph);
//   args[1]: `backtracking` flag forwarded to LineGraph (bool).
// Returns the new graph as a GraphHandle; _CAPI_DGLGraphFree deletes it.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle ghandle = args[0];
    bool backtracking = args[1];
    const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
    const Graph* gptr = dynamic_cast<const Graph*>(ptr);
    CHECK(gptr) << "_CAPI_DGLGraphLineGraph isn't implemented in immutable graph";
    // Construct the result in a single new-expression: if LineGraph throws,
    // the allocation is released automatically instead of leaking. The
    // GraphInterface* type matches what consumers cast the handle back to.
    GraphInterface* lgptr = new Graph(GraphOp::LineGraph(gptr, backtracking));
    GraphHandle lghandle = lgptr;
    *rv = lghandle;
  });
GaiYu0's avatar
GaiYu0 committed
434

435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
/*!
 * \brief Run ImmutableGraph::NeighborUniformSample once per seed array.
 *
 * Argument layout:
 *   - args[0]: graph handle (must be an ImmutableGraph);
 *   - args[1 .. num_seeds]: seed arrays (tensors converted through DLPack);
 *   - args[num_seeds + 1]: neighbor type (string);
 *   - args[num_seeds + 2]: number of hops (int);
 *   - args[num_seeds + 3]: number of neighbors (int);
 *   - args[num_seeds + 4]: number of valid seed arrays (int), at most num_seeds.
 *
 * Only the first `num_valid_seeds` seed arrays are sampled, in an OpenMP
 * parallel loop; the remaining result slots stay default-constructed.
 * The result vector is returned through ConvertSubgraphToPackedFunc.
 *
 * \tparam num_seeds Number of seed-array argument slots.
 */
template<int num_seeds>
void CAPI_NeighborUniformSample(DGLArgs args, DGLRetValue* rv) {
  GraphHandle handle = args[0];
  // Every seed slot is converted, including those past num_valid_seeds.
  std::vector<IdArray> seed_arrays(num_seeds);
  for (int slot = 0; slot < num_seeds; ++slot) {
    seed_arrays[slot] = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[slot + 1]));
  }
  std::string neighbor_type = args[num_seeds + 1];
  const int hop_count = args[num_seeds + 2];
  const int fanout = args[num_seeds + 3];
  const int num_valid_seeds = args[num_seeds + 4];

  const GraphInterface *base = static_cast<const GraphInterface *>(handle);
  const ImmutableGraph *graph = dynamic_cast<const ImmutableGraph*>(base);
  CHECK(graph) << "sampling isn't implemented in mutable graph";
  CHECK(num_valid_seeds <= num_seeds);

  std::vector<SampledSubgraph> results(num_seeds);
#pragma omp parallel for
  for (int i = 0; i < num_valid_seeds; i++) {
    results[i] = graph->NeighborUniformSample(seed_arrays[i], neighbor_type, hop_count, fanout);
  }
  *rv = ConvertSubgraphToPackedFunc(results);
}

// Fixed-arity entry points for uniform neighbor sampling. Each name binds
// CAPI_NeighborUniformSample instantiated with the number of seed-array
// slots given by its suffix (no suffix means one slot).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling")
  .set_body(CAPI_NeighborUniformSample<1>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling2")
  .set_body(CAPI_NeighborUniformSample<2>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling4")
  .set_body(CAPI_NeighborUniformSample<4>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling8")
  .set_body(CAPI_NeighborUniformSample<8>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling16")
  .set_body(CAPI_NeighborUniformSample<16>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling32")
  .set_body(CAPI_NeighborUniformSample<32>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling64")
  .set_body(CAPI_NeighborUniformSample<64>);

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphUniformSampling128")
  .set_body(CAPI_NeighborUniformSample<128>);

// Return GraphInterface::GetAdj(transpose, format), wrapped by
// ConvertAdjToPackedFunc.
//   args[0]: graph handle;
//   args[1]: transpose flag (bool);
//   args[2]: format name (string) passed through unchanged.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    bool transposed = args[1];
    std::string layout = args[2];
    const GraphInterface *graph = static_cast<const GraphInterface *>(handle);
    auto adjacency = graph->GetAdj(transposed, layout);
    *rv = ConvertAdjToPackedFunc(adjacency);
  });

Minjie Wang's avatar
Minjie Wang committed
484
}  // namespace dgl