utils.h 1.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/*!
 *  Copyright (c) 2018 by Contributors
 * \file kernel/utils.h
 * \brief Kernel utilities
 */
#ifndef DGL_KERNEL_UTILS_H_
#define DGL_KERNEL_UTILS_H_

#include <minigun/csr.h>
#include <dlpack/dlpack.h>
#include <dgl/runtime/ndarray.h>

#include <cassert>
#include <cstdlib>
#include <vector>

namespace dgl {
namespace kernel {
namespace utils {

/*!
 * \brief Build the sentinel NDArray used to stand for a "none" value.
 *
 * The sentinel is an empty-shape (0-dimensional) int32 array placed on CPU
 * device 0. IsNoneArray() recognizes it purely by its zero ndim.
 *
 * \return A freshly created 0-dim NDArray.
 */
inline runtime::NDArray NoneArray() {
  const DLDataType kNoneDType{kDLInt, 32, 1};
  const DLContext kNoneCtx{kDLCPU, 0};
  // An empty shape list yields a 0-dimensional array.
  return runtime::NDArray::Empty({}, kNoneDType, kNoneCtx);
}

/* !\brief Return true if the NDArray is none. */
inline bool IsNoneArray(runtime::NDArray array) {
  return array->ndim == 0;
}

/*!
 * \brief Find a number of threads that is smaller than dim and max_nthrs
 * and is also a power of two.
 *
 * NOTE(review): the definition is not in this header. Confirm against the
 * implementation whether "smaller" means strictly less or not greater, and
 * whether the largest such power of two is the one returned.
 *
 * \param dim Presumably the size of the dimension being parallelized over.
 * \param max_nthrs Upper bound on the thread count.
 * \return The chosen thread count.
 */
int FindNumThreads(int dim, int max_nthrs);

/*!
 * \brief Compute the total number of feature elements.
 *
 * NOTE(review): presumably the product of all dimensions of feat_array
 * except the first (per-node/per-edge) axis. Confirm against the definition.
 *
 * \param feat_array The feature array.
 * \return The number of feature elements.
 */
int64_t ComputeXLength(runtime::NDArray feat_array);

/*!
 * \brief Compute the total number of elements in the array.
 *
 * \param array The array to inspect.
 * \return The element count (presumably the product of its shape; confirm
 *         against the definition).
 */
int64_t NElements(const runtime::NDArray& array);

/*!
 * \brief Compute the product of all entries in the given vector.
 *
 * \param vec The values to multiply.
 * \return The product of the entries. NOTE(review): the result for an empty
 *         vector (presumably 1) should be confirmed against the definition.
 */
int64_t Prod(const std::vector<int64_t>& vec);

/*!
 * \brief Fill a buffer with a constant value.
 *
 * \tparam XPU Device selector (presumably a DLDeviceType value such as
 *             kDLCPU or kDLGPU; confirm against the specializations).
 * \tparam DType Element type of the buffer.
 * \param ctx Context of the buffer pointed to by ptr.
 * \param ptr Start of the buffer to fill.
 * \param length Number of elements to write (presumably counted in DType
 *               units, not bytes).
 * \param val Value written to every element.
 */
template <int XPU, typename DType>
void Fill(const DLContext& ctx, DType* ptr, size_t length, DType val);

/*!
 * \brief Create a minigun CSR from two ndarrays.
 *
 * No data is copied here: the returned Csr stores the raw data pointers of
 * \p indptr and \p indices, so those arrays presumably must outlive it.
 *
 * \tparam Idx Integer type of the index entries. Its bit width must match
 *             the dtype of both arrays, because their untyped data pointers
 *             are reinterpreted as Idx*.
 * \param indptr 1-D row offset array.
 * \param indices 1-D column index array.
 * \return A minigun::Csr whose row_offsets / column_indices alias the inputs.
 */
template <typename Idx>
minigun::Csr<Idx> CreateCsr(const runtime::NDArray& indptr,
                            const runtime::NDArray& indices) {
  // shape[0] below is out of bounds for a 0-dim (e.g. none) array.
  assert(indptr->ndim == 1);
  assert(indices->ndim == 1);
  // A dtype/Idx width mismatch would silently produce garbage indices.
  assert(static_cast<size_t>(indptr->dtype.bits) == sizeof(Idx) * 8);
  assert(static_cast<size_t>(indices->dtype.bits) == sizeof(Idx) * 8);

  minigun::Csr<Idx> csr;
  csr.row_offsets.data = static_cast<Idx*>(indptr->data);
  csr.row_offsets.length = indptr->shape[0];
  csr.column_indices.data = static_cast<Idx*>(indices->data);
  csr.column_indices.length = indices->shape[0];
  return csr;
}

}  // namespace utils
}  // namespace kernel
}  // namespace dgl

#endif  // DGL_KERNEL_UTILS_H_