"docs/vscode:/vscode.git/clone" did not exist on "358db43a77ba8cf09d620e8f874a6e8e882b75f6"
check.h 1.87 KB
Newer Older
/**
 *  Copyright (c) 2019 by Contributors
 * @file array/check.h
 * @brief DGL check utilities
 */
#ifndef DGL_ARRAY_CHECK_H_
#define DGL_ARRAY_CHECK_H_

#include <dgl/array.h>
10
11
#include <dgl/runtime/ndarray.h>

12
#include <string>
13
#include <vector>
14
15
16
17
18
19

namespace dgl {
namespace aten {

// Check whether the given arguments have the same context.
inline void CheckCtx(
20
    const DGLContext& ctx, const std::vector<NDArray>& arrays,
21
22
    const std::vector<std::string>& names) {
  for (size_t i = 0; i < arrays.size(); ++i) {
23
    if (IsNullArray(arrays[i])) continue;
24
    CHECK_EQ(ctx, arrays[i]->ctx)
25
26
        << "Expected device context " << ctx << ". But got " << arrays[i]->ctx
        << " for " << names[i] << ".";
27
28
29
30
31
  }
}

// Check whether input tensors are contiguous.
inline void CheckContiguous(
32
    const std::vector<NDArray>& arrays, const std::vector<std::string>& names) {
33
  for (size_t i = 0; i < arrays.size(); ++i) {
34
    if (IsNullArray(arrays[i])) continue;
35
    CHECK(arrays[i].IsContiguous())
36
        << "Expect " << names[i] << " to be a contiguous tensor";
37
38
39
40
41
  }
}

// Check whether input tensors have valid shape.
inline void CheckShape(
42
43
    const std::vector<uint64_t>& gdim, const std::vector<int>& uev_idx,
    const std::vector<NDArray>& arrays, const std::vector<std::string>& names) {
44
  for (size_t i = 0; i < arrays.size(); ++i) {
45
    if (IsNullArray(arrays[i])) continue;
46
    CHECK_GE(arrays[i]->ndim, 2)
47
48
49
50
        << "Expect " << names[i] << " to have ndim >= 2, "
        << "Note that for scalar feature we expand its "
        << "dimension with an additional dimension of "
        << "length one.";
51
    CHECK_EQ(gdim[uev_idx[i]], arrays[i]->shape[0])
52
53
54
        << "Expect " << names[i] << " to have size " << gdim[uev_idx[i]]
        << " on the first dimension, "
        << "but got " << arrays[i]->shape[0];
55
56
57
58
59
60
61
  }
}

}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CHECK_H_