test_aten.cc 4.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#include <gtest/gtest.h>
#include <dgl/array.h>
#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

// Covers the IdArray factory helpers: NewIdArray with explicit and default
// arguments, and VecToIdArray with non-empty and empty input vectors.
TEST(ArrayTest, TestCreate) {
  IdArray a = aten::NewIdArray(100, CTX, 32);
  ASSERT_EQ(a->dtype.bits, 32);
  ASSERT_EQ(a->shape[0], 100);

  // Zero-length allocation relying on NewIdArray's default arguments.
  a = aten::NewIdArray(0);
  ASSERT_EQ(a->shape[0], 0);

  std::vector<int64_t> vec = {2, 94, 232, 30};
  a = aten::VecToIdArray(vec, 32);
  // vec.size() is an unsigned size_t; cast it so the gtest comparison is not a
  // signed/unsigned mix (-Wsign-compare).
  ASSERT_EQ(Len(a), static_cast<int64_t>(vec.size()));
  ASSERT_EQ(a->dtype.bits, 32);
  // The 64-bit source values must be narrowed into a 32-bit array unchanged.
  for (int i = 0; i < Len(a); ++i) {
    ASSERT_EQ(Ptr<int32_t>(a)[i], vec[i]);
  }

  a = aten::VecToIdArray(std::vector<int32_t>());
  ASSERT_EQ(Len(a), 0);
}

// Range(10, 10, ...) must be empty, and Range(10, 20, ...) must hold the ten
// consecutive values 10..19 with the requested 32-bit width.
TEST(ArrayTest, TestRange) {
  IdArray a = aten::Range(10, 10, 64, CTX);
  ASSERT_EQ(Len(a), 0);
  a = aten::Range(10, 20, 32, CTX);
  ASSERT_EQ(Len(a), 10);
  ASSERT_EQ(a->dtype.bits, 32);
  for (int i = 0; i < 10; ++i)
    ASSERT_EQ(Ptr<int32_t>(a)[i], i + 10);
}

// Full(value, length, nbits, ctx): a zero-length fill is empty, and a 13-long
// 64-bit fill holds the (negative) fill value in every slot.
TEST(ArrayTest, TestFull) {
  IdArray a = aten::Full(-100, 0, 32, CTX);
  ASSERT_EQ(Len(a), 0);
  a = aten::Full(-100, 13, 64, CTX);
  ASSERT_EQ(Len(a), 13);
  ASSERT_EQ(a->dtype.bits, 64);
  for (int i = 0; i < 13; ++i)
    ASSERT_EQ(Ptr<int64_t>(a)[i], -100);
}

// Clone must handle empty arrays and must copy the data: after cloning,
// writing into the clone is not visible through the original.
TEST(ArrayTest, TestClone) {
  IdArray a = aten::NewIdArray(0);
  IdArray b = aten::Clone(a);
  ASSERT_EQ(Len(b), 0);

  a = aten::Range(0, 10, 32, CTX);
  b = aten::Clone(a);
  for (int i = 0; i < 10; ++i) {
    ASSERT_EQ(PI32(b)[i], i);
  }
  // Mutate the clone, then verify the source is untouched.
  PI32(b)[0] = -1;
  for (int i = 0; i < 10; ++i) {
    ASSERT_EQ(PI32(a)[i], i);
  }
}

// Widening a 32-bit Range(0, 10) to 64 bits must change the dtype and keep
// every value.
TEST(ArrayTest, TestAsNumBits) {
  IdArray a = aten::Range(0, 10, 32, CTX);
  a = aten::AsNumBits(a, 64);
  ASSERT_EQ(a->dtype.bits, 64);
  for (int i = 0; i < 10; ++i)
    ASSERT_EQ(PI64(a)[i], i);
}

template <typename IDX>
void _TestArith() {
  const int N = 100;
  IdArray a = aten::Full(-10, N, sizeof(IDX)*8, CTX);
  IdArray b = aten::Full(7, N, sizeof(IDX)*8, CTX);

  IdArray c = aten::Add(a, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -3);
  c = aten::Sub(a, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -17);
  c = aten::Mul(a, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -70);
  c = aten::Div(a, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -1);

  const int val = -3;
  c = aten::Add(a, val);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -13);
  c = aten::Sub(a, val);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -7);
  c = aten::Mul(a, val);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], 30);
  c = aten::Div(a, val);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], 3);
  c = aten::Add(val, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], 4);
  c = aten::Sub(val, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -10);
  c = aten::Mul(val, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], -21);
  c = aten::Div(val, b);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], 0);

  a = aten::Range(0, N, sizeof(IDX)*8, CTX);
  c = aten::LT(a, 50);
  for (int i = 0; i < N; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], (int)(i < 50));
}

// Runs the arithmetic checks for 32-bit and 64-bit index types. Each
// instantiation is traced so a failure message names the failing IDX.
TEST(ArrayTest, TestArith) {
  {
    SCOPED_TRACE("IDX = int32_t");
    _TestArith<int32_t>();
  }
  {
    SCOPED_TRACE("IDX = int64_t");
    _TestArith<int64_t>();
  }
}

template <typename IDX>
void _TestHStack() {
  IdArray a = aten::Range(0, 100, sizeof(IDX)*8, CTX);
  IdArray b = aten::Range(100, 200, sizeof(IDX)*8, CTX);
  IdArray c = aten::HStack(a, b);
  ASSERT_EQ(c->ndim, 1);
  ASSERT_EQ(c->shape[0], 200);
  for (int i = 0; i < 200; ++i)
    ASSERT_EQ(Ptr<IDX>(c)[i], i);
}

// Runs the HStack checks for 32-bit and 64-bit index types, tracing each
// instantiation so a failure identifies the IDX involved.
TEST(ArrayTest, TestHStack) {
  {
    SCOPED_TRACE("IDX = int32_t");
    _TestHStack<int32_t>();
  }
  {
    SCOPED_TRACE("IDX = int64_t");
    _TestHStack<int64_t>();
  }
}

template <typename IDX>
void _TestIndexSelect() {
  IdArray a = aten::Range(0, 100, sizeof(IDX)*8, CTX);
  ASSERT_EQ(aten::IndexSelect(a, 50), 50);
  IdArray b = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, CTX);
  IdArray c = aten::IndexSelect(a, b);
  ASSERT_TRUE(ArrayEQ<IDX>(b, c));
}

// Runs the IndexSelect checks for 32-bit and 64-bit index types, tracing each
// instantiation so a failure identifies the IDX involved.
TEST(ArrayTest, TestIndexSelect) {
  {
    SCOPED_TRACE("IDX = int32_t");
    _TestIndexSelect<int32_t>();
  }
  {
    SCOPED_TRACE("IDX = int64_t");
    _TestIndexSelect<int64_t>();
  }
}

template <typename IDX>
void _TestRelabel_() {
  IdArray a = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, CTX);
  IdArray b = aten::VecToIdArray(std::vector<IDX>({20, 5, 6}), sizeof(IDX)*8, CTX);
  IdArray c = aten::Relabel_({a, b});
  IdArray ta = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, CTX);
  IdArray tb = aten::VecToIdArray(std::vector<IDX>({1, 3, 4}), sizeof(IDX)*8, CTX);
  IdArray tc = aten::VecToIdArray(std::vector<IDX>({0, 20, 10, 5, 6}), sizeof(IDX)*8, CTX);
  ASSERT_TRUE(ArrayEQ<IDX>(a, ta));
  ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
  ASSERT_TRUE(ArrayEQ<IDX>(c, tc));
}

// Runs the Relabel_ checks for 32-bit and 64-bit index types, tracing each
// instantiation so a failure identifies the IDX involved.
TEST(ArrayTest, TestRelabel_) {
  {
    SCOPED_TRACE("IDX = int32_t");
    _TestRelabel_<int32_t>();
  }
  {
    SCOPED_TRACE("IDX = int64_t");
    _TestRelabel_<int64_t>();
  }
}