scheduler_apis.cc 2.4 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2018 by Contributors
 * \file scheduler/scheduler_apis.cc
 * \brief DGL scheduler APIs
 */
Lingfan Yu's avatar
Lingfan Yu committed
6
7
#include <numeric>

#include <dgl/graph.h>
#include <dgl/scheduler.h>
8
#include "../c_api_common.h"
Lingfan Yu's avatar
Lingfan Yu committed
9

10
11
12
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLRetValue;
using dgl::runtime::NDArray;
Lingfan Yu's avatar
Lingfan Yu committed
13
14
15

namespace dgl {

16
17
// Generic degree-bucketing entry point.
//   args[0]: message ids (DLPack tensor)
//   args[1]: node each message is sent to (DLPack tensor; used like edge dst
//            ids by the other entry points in this file)
//   args[2]: receiving node ids (DLPack tensor)
// The vector of arrays produced by sched::DegreeBucketing is handed back to
// the caller wrapped by ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketing")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // Every positional argument arrives as a DLPack tensor; view it as an IdArray.
    auto to_id_array = [&args](int pos) {
      return IdArray::FromDLPack(CreateTmpDLManagedTensor(args[pos]));
    };
    const IdArray message_ids = to_id_array(0);
    const IdArray message_dsts = to_id_array(1);
    const IdArray recv_nodes = to_id_array(2);

    const auto buckets = sched::DegreeBucketing(message_ids, message_dsts, recv_nodes);
    *rv = ConvertNDArrayVectorToPackedFunc(buckets);
  });

25
26
// Degree bucketing where every input entry is its own message.
//   args[0]: per-message node ids (DLPack tensor), `vids`.
// Message i is given id i (an arange over vids), and `vids` is passed both as
// the per-message node array and as the receiving-node array of
// sched::DegreeBucketing. The resulting arrays are returned wrapped by
// ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketingForEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0]));

    // Build msg_ids = [0, 1, ..., n_msgs - 1]. The buffer is filled through an
    // int64_t pointer, so it must be allocated as int64 explicitly: allocating
    // it with vids->dtype overflowed the buffer whenever vids was narrower
    // than 64 bits.
    const int64_t n_msgs = vids->shape[0];
    IdArray msg_ids = IdArray::Empty({n_msgs}, DLDataType{kDLInt, 64, 1}, vids->ctx);
    int64_t* mid_data = static_cast<int64_t*>(msg_ids->data);
    // NOTE(review): this is a host-side write; assumes vids->ctx is a CPU context.
    std::iota(mid_data, mid_data + n_msgs, int64_t{0});

    *rv = ConvertNDArrayVectorToPackedFunc(sched::DegreeBucketing(msg_ids, vids, vids));
  });

38
39
// Degree bucketing over the in-edges of a set of receiving nodes.
//   args[0]: graph handle (cast to const Graph*)
//   args[1]: receiving node ids (DLPack tensor)
// Each in-edge of the receivers (from Graph::InEdges) acts as one message:
// its edge id is the message id and its dst is the node the message goes to.
// The arrays from sched::DegreeBucketing are returned wrapped by
// ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketingForRecvNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle handle = args[0];
    const IdArray recv_nodes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));

    const Graph* graph = static_cast<Graph*>(handle);
    const auto in_edges = graph->InEdges(recv_nodes);

    *rv = ConvertNDArrayVectorToPackedFunc(
        sched::DegreeBucketing(in_edges.id, in_edges.dst, recv_nodes));
  });

47
48
// Degree bucketing over every edge of a graph.
//   args[0]: graph handle (cast to const Graph*)
// All edges (Graph::Edges(false)) act as messages: edge id is the message id
// and edge dst is the node the message goes to. The receiving-node set is
// every vertex, 0..NumVertices()-1. The arrays from sched::DegreeBucketing
// are returned wrapped by ConvertNDArrayVectorToPackedFunc.
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketingForFullGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const auto& edges = gptr->Edges(false);

    // Build nids = [0, 1, ..., n_vertices - 1]. The buffer is filled through
    // an int64_t pointer, so it is allocated as int64 explicitly rather than
    // borrowing edges.dst->dtype, which would overflow the buffer if the
    // edge arrays ever used a narrower id type.
    const int64_t n_vertices = gptr->NumVertices();
    IdArray nids = IdArray::Empty({n_vertices}, DLDataType{kDLInt, 64, 1}, edges.dst->ctx);
    int64_t* nid_data = static_cast<int64_t*>(nids->data);
    // NOTE(review): host-side write; assumes edges.dst->ctx is a CPU context.
    std::iota(nid_data, nid_data + n_vertices, int64_t{0});

    *rv = ConvertNDArrayVectorToPackedFunc(sched::DegreeBucketing(edges.id, edges.dst, nids));
  });
60
}  // namespace dgl