/*!
 *  Copyright (c) 2018 by Contributors
 * \file scheduler/scheduler_apis.cc
 * \brief DGL scheduler APIs
 */
#include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/scheduler.h>
#include "../c_api_common.h"
#include "../array/cpu/array_utils.h"

using dgl::runtime::DGLArgs;
using dgl::runtime::DGLRetValue;
using dgl::runtime::NDArray;

namespace dgl {

// C API entry point "runtime.degree_bucketing._CAPI_DGLDegreeBucketing".
//
// Positional arguments (all IdArray, required to share a single dtype):
//   [0] msg_ids, [1] vids, [2] nids
// The shared id dtype picks the IdType instantiation of
// sched::DegreeBucketing. The vector of arrays it returns is handed back to
// the caller wrapped by ConvertNDArrayVectorToPackedFunc.
//
// NOTE(review): the local names are intentionally left as msg_ids/vids/nids;
// CHECK_SAME_DTYPE presumably stringifies its arguments into the error
// message. Confirm against the macro before renaming them.
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLDegreeBucketing")
.set_body([](DGLArgs call_args, DGLRetValue* result) {
    const IdArray msg_ids = call_args[0];
    const IdArray vids = call_args[1];
    const IdArray nids = call_args[2];

    // Every id array must agree with msg_ids on dtype before dispatching.
    CHECK_SAME_DTYPE(msg_ids, vids);
    CHECK_SAME_DTYPE(msg_ids, nids);

    // Select the integer IdType matching msg_ids' dtype, then run bucketing.
    ATEN_ID_TYPE_SWITCH(msg_ids->dtype, IdType, {
      const auto buckets = sched::DegreeBucketing<IdType>(msg_ids, vids, nids);
      *result = ConvertNDArrayVectorToPackedFunc(buckets);
    });
  });

// C API entry point "runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree".
//
// Positional arguments (all IdArray, required to share a single dtype):
//   [0] uids, [1] vids, [2] eids
// The shared id dtype picks the IdType instantiation of
// sched::GroupEdgeByNodeDegree. The vector of arrays it returns is handed back
// to the caller wrapped by ConvertNDArrayVectorToPackedFunc.
//
// NOTE(review): the local names are intentionally left as uids/vids/eids;
// CHECK_SAME_DTYPE presumably stringifies its arguments into the error
// message. Confirm against the macro before renaming them.
DGL_REGISTER_GLOBAL("runtime.degree_bucketing._CAPI_DGLGroupEdgeByNodeDegree")
.set_body([](DGLArgs call_args, DGLRetValue* result) {
    const IdArray uids = call_args[0];
    const IdArray vids = call_args[1];
    const IdArray eids = call_args[2];

    // Every id array must agree with uids on dtype before dispatching.
    CHECK_SAME_DTYPE(uids, vids);
    CHECK_SAME_DTYPE(uids, eids);

    // Select the integer IdType matching uids' dtype, then group the edges.
    ATEN_ID_TYPE_SWITCH(uids->dtype, IdType, {
      const auto groups = sched::GroupEdgeByNodeDegree<IdType>(uids, vids, eids);
      *result = ConvertNDArrayVectorToPackedFunc(groups);
    });
  });

}  // namespace dgl