utils.cc 1.28 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file utils.cc
 * \brief DGL util functions
 */

7
#include <dmlc/omp.h>
8

9
10

#include <dgl/aten/coo.h>
11
#include <dgl/packed_func_ext.h>
12
13
#include <utility>

14
#include "../c_api_common.h"
15
16
#include "../array/array_op.h"

17
18

using namespace dgl::runtime;
19
using namespace dgl::aten::impl;
20
21
22
23
24
25
26
27
28

namespace dgl {

// Packed function "utils.internal._CAPI_DGLSetOMPThreads".
// Converts args[0] to an int and forwards it to omp_set_num_threads().
// Nothing is written to the return slot `rv`.
DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLSetOMPThreads")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // The value is passed through unchanged; no range check is done here.
    const int thread_count = args[0];
    omp_set_num_threads(thread_count);
  });

29
30
31
32
// Packed function "utils.internal._CAPI_DGLGetOMPThreads".
// Does not read `args`; stores the result of omp_get_max_threads() in `rv`.
DGL_REGISTER_GLOBAL("utils.internal._CAPI_DGLGetOMPThreads")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    const int max_threads = omp_get_max_threads();
    *rv = max_threads;
  });
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

// Packed function "utils.checks._CAPI_DGLCOOIsSorted".
//
// Arguments (read positionally from `args`):
//   args[0]  IdArray  src      - passed to COOMatrix as its 3rd argument
//   args[1]  IdArray  dst      - passed to COOMatrix as its 4th argument
//   args[2]  int64_t  num_src  - passed to COOMatrix as its 1st argument
//   args[3]  int64_t  num_dst  - passed to COOMatrix as its 2nd argument
//
// Result (an int64_t stored in `rv`):
//   0  rows are not sorted
//   1  rows are sorted, columns are not
//   2  both rows and columns are sorted
DGL_REGISTER_GLOBAL("utils.checks._CAPI_DGLCOOIsSorted")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    IdArray src = args[0];
    IdArray dst = args[1];
    int64_t num_src = args[2];
    int64_t num_dst = args[3];

    // std::get<> is used instead of std::tie so that only <utility> (already
    // included by this file) is required to unpack the result.
    const auto sorted_flags = COOIsSorted(
        aten::COOMatrix(num_src, num_dst, src, dst));
    const bool row_sorted = std::get<0>(sorted_flags);
    const bool col_sorted = std::get<1>(sorted_flags);

    // col_sorted is only expected to be true when row_sorted is true
    assert(row_sorted || !col_sorted);

    // 0 for unsorted, 1 for row sorted, 2 for row and col sorted.
    // The status is derived explicitly rather than by summing the flags so
    // that, when the assert above is compiled out (NDEBUG), an unsorted-row
    // result can never be reported as "row sorted".
    int64_t sorted_status = 0;
    if (row_sorted) {
      sorted_status = col_sorted ? 2 : 1;
    }
    *rv = sorted_status;
  });

53
}  // namespace dgl