scheduler_apis.cc 1.24 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2018 by Contributors
 * \file scheduler/scheduler_apis.cc
 * \brief DGL scheduler APIs
 */
Lingfan Yu's avatar
Lingfan Yu committed
6
7
#include <dgl/graph.h>
#include <dgl/scheduler.h>
8
#include "../c_api_common.h"
Lingfan Yu's avatar
Lingfan Yu committed
9

10
11
12
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLRetValue;
using dgl::runtime::NDArray;
Lingfan Yu's avatar
Lingfan Yu committed
13
14
15

namespace dgl {

16
17
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketing")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
18
19
20
21
22
23
    const IdArray msg_ids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0]));
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray nids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    *rv = ConvertNDArrayVectorToPackedFunc(sched::DegreeBucketing(msg_ids, vids, nids));
  });

24
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree")
25
.set_body([] (DGLArgs args, DGLRetValue* rv) {
26
    const IdArray uids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[0]));
27
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
28
29
30
    const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    *rv = ConvertNDArrayVectorToPackedFunc(
            sched::GroupEdgeByNodeDegree(uids, vids, eids));
Lingfan Yu's avatar
Lingfan Yu committed
31
32
  });

33
}  // namespace dgl