graph_apis.cc 2.13 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <cstdint>
#include <stdexcept>

#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/graph.h>

using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;

namespace dgl {

typedef void* GraphHandle;

void DGLGraphCreate(TVMArgs args, TVMRetValue* rv) {
  GraphHandle ghandle = new Graph();
  *rv = ghandle;
}

TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphCreate")
.set_body(DGLGraphCreate);

void DGLGraphFree(TVMArgs args, TVMRetValue* rv) {
  GraphHandle ghandle = args[0];
  Graph* gptr = static_cast<Graph*>(ghandle);
  delete gptr;
}

TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphFree")
.set_body(DGLGraphFree);

void DGLGraphAddVertices(TVMArgs args, TVMRetValue* rv) {
  GraphHandle ghandle = args[0];
  Graph* gptr = static_cast<Graph*>(ghandle);
  uint64_t num_vertices = args[1];
  gptr->AddVertices(num_vertices);
}

TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddVertices")
.set_body(DGLGraphAddVertices);

void DGLGraphAddEdge(TVMArgs args, TVMRetValue* rv) {
  GraphHandle ghandle = args[0];
  Graph* gptr = static_cast<Graph*>(ghandle);
  const dgl_id_t src = args[1];
  const dgl_id_t dst = args[2];
  gptr->AddEdge(src, dst);
}

TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdge")
.set_body(DGLGraphAddEdge);

void DGLGraphAddEdges(TVMArgs args, TVMRetValue* rv) {
  GraphHandle ghandle = args[0];
  Graph* gptr = static_cast<Graph*>(ghandle);
  const IdArray src = args[1];
  const IdArray dst = args[2];
  gptr->AddEdges(src, dst);
}

TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphAddEdges")
.set_body(DGLGraphAddEdges);

void DGLGraphNumVertices(TVMArgs args, TVMRetValue* rv) {
  GraphHandle ghandle = args[0];
  const Graph* gptr = static_cast<Graph*>(ghandle);
  *rv = static_cast<int64_t>(gptr->NumVertices());
}

TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumVertices")
.set_body(DGLGraphNumVertices);

void DGLGraphNumEdges(TVMArgs args, TVMRetValue* rv) {
  GraphHandle ghandle = args[0];
  const Graph* gptr = static_cast<Graph*>(ghandle);
  *rv = static_cast<int64_t>(gptr->NumEdges());
}

TVM_REGISTER_GLOBAL("cgraph._CAPI_DGLGraphNumEdges")
.set_body(DGLGraphNumEdges);

}  // namespace dgl