utils.cc 1.32 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file utils.cc
 * \brief DGL util functions
 */

7
#include <dgl/aten/coo.h>
8
#include <dgl/packed_func_ext.h>
9
10
#include <dmlc/omp.h>

11
12
13
#include <utility>

#include "../array/array_op.h"
14
#include "../c_api_common.h"
15
16

using namespace dgl::runtime;
17
using namespace dgl::aten::impl;
18
19
20
21

namespace dgl {

/*!
 * \brief Packed function that forwards a thread count to OpenMP.
 *
 * args[0] is converted to an int and handed to omp_set_num_threads().
 * Nothing is written to the return value.
 */
DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLSetOMPThreads")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // The value is passed through as-is; no range check is done here.
      const int requested_threads = args[0];
      omp_set_num_threads(requested_threads);
    });
26

27
/*!
 * \brief Packed function that reports OpenMP's current thread limit.
 *
 * Takes no arguments; stores the result of omp_get_max_threads() in the
 * return value.
 */
DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLGetOMPThreads")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int max_threads = omp_get_max_threads();
      *rv = max_threads;
    });
31
32

/*!
 * \brief Packed function that reports how a COO edge list is sorted.
 *
 * Arguments:
 *   args[0], args[1] - IdArrays used as the row and column arrays of the
 *                      COO matrix.
 *   args[2], args[3] - int64 values used as the matrix's two dimensions.
 *
 * The return value is an int64 status code:
 *   0 - unsorted
 *   1 - sorted by row only
 *   2 - sorted by row and by column
 */
DGL_REGISTER_GLOBAL("utils.checks._CAPI_DGLCOOIsSorted")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      IdArray row_ids = args[0];
      IdArray col_ids = args[1];
      int64_t n_rows = args[2];
      int64_t n_cols = args[3];

      bool sorted_by_row, sorted_by_col;
      std::tie(sorted_by_row, sorted_by_col) =
          COOIsSorted(aten::COOMatrix(n_rows, n_cols, row_ids, col_ids));

      // Invariant: column-sortedness may only be reported together with
      // row-sortedness.
      assert(sorted_by_row || !sorted_by_col);

      // Each true flag adds one, giving the 0 / 1 / 2 code described above.
      const int64_t status_code = static_cast<int64_t>(sorted_by_row) +
                                  static_cast<int64_t>(sorted_by_col);
      *rv = status_code;
    });
50

51
}  // namespace dgl