/*!
 *  Copyright (c) 2019 by Contributors
 * \file test_unit_graph.cc
 * \brief Test UnitGraph
 */
#include <gtest/gtest.h>
#include <dgl/array.h>
#include <memory>
#include <vector>
#include <dgl/immutable_graph.h>
#include "./common.h"
#include "./../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"

using namespace dgl;
using namespace dgl::runtime;

/*!
 * \brief Build the 4x3 CSR fixture shown below on the requested device.
 * \tparam IdType Index type of indptr/indices (int32_t or int64_t).
 * \param ctx Device context the index arrays are allocated on.
 * \return A CSR matrix with 6 non-zeros and no explicit data (edge id) array.
 */
template <typename IdType>
aten::CSRMatrix CSR1(DLContext ctx) {
  /*
   * G = [[0, 0, 1],
   *      [1, 0, 1],
   *      [0, 1, 0],
   *      [1, 0, 1]]
   */
  // Allocate on `ctx` rather than a fixed context so that the GPU
  // instantiations of the tests really operate on GPU arrays.
  IdArray g_indptr =
    aten::VecToIdArray(std::vector<IdType>({0, 1, 3, 4, 6}), sizeof(IdType)*8, ctx);
  IdArray g_indices =
    aten::VecToIdArray(std::vector<IdType>({2, 0, 2, 1, 0, 2}), sizeof(IdType)*8, ctx);

  // Trailing `false`: presumably the "sorted" flag of aten::CSRMatrix — confirm.
  return aten::CSRMatrix(
    4,
    3,
    g_indptr,
    g_indices,
    aten::NullArray(),
    false);
}

// Explicit instantiations for the two index widths exercised by the tests.
template aten::CSRMatrix CSR1<int32_t>(DLContext ctx);
template aten::CSRMatrix CSR1<int64_t>(DLContext ctx);

/*!
 * \brief Build the 2x3 COO fixture shown below on the requested device.
 * \tparam IdType Index type of row/col (int32_t or int64_t).
 * \param ctx Device context the index arrays are allocated on.
 * \return A COO matrix with 3 non-zeros and no explicit data (edge id) array.
 */
template <typename IdType>
aten::COOMatrix COO1(DLContext ctx) {
  /*
   * G = [[1, 1, 0],
   *      [0, 1, 0]]
   */
  // Allocate on `ctx` rather than a fixed context so that the GPU
  // instantiations of the tests really operate on GPU arrays.
  IdArray g_row =
    aten::VecToIdArray(std::vector<IdType>({0, 0, 1}), sizeof(IdType)*8, ctx);
  IdArray g_col =
    aten::VecToIdArray(std::vector<IdType>({0, 1, 1}), sizeof(IdType)*8, ctx);

  // The two trailing `true`s presumably mark rows and columns as sorted, which
  // holds for the entries above — confirm against aten::COOMatrix.
  return aten::COOMatrix(
    2,
    3,
    g_row,
    g_col,
    aten::NullArray(),
    true,
    true);
}

// Explicit instantiations for the two index widths exercised by the tests.
template aten::COOMatrix COO1<int32_t>(DLContext ctx);
template aten::COOMatrix COO1<int64_t>(DLContext ctx);

/*!
 * \brief Check which sparse formats the graph factories materialize on creation.
 *
 * Judging from the expected values, GetCreatedFormats() is a bitmask with
 * COO = 1, CSR = 2 and CSC = 4. Each factory should create only the format it
 * was built from. UnitGraph::CreateFromCOO should create only COO regardless
 * of the allowed-format code it is given. A single-vertex-type graph must still
 * convert to an ImmutableGraph.
 */
template <typename IdType>
void _TestUnitGraph(DLContext ctx) {
  const aten::CSRMatrix csr = CSR1<IdType>(ctx);
  const aten::COOMatrix coo = COO1<IdType>(ctx);

  // Each factory records exactly the format it was handed.
  auto g = CreateFromCSC(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 4);

  g = CreateFromCSR(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 2);

  g = CreateFromCOO(2, coo);
  ASSERT_EQ(g->GetCreatedFormats(), 1);

  // Build from raw src/dst arrays (int64, default context) with varying
  // allowed-format codes.
  auto src = aten::VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = aten::VecToIdArray<int64_t>({1, 6, 2, 6});

  // Only COO allowed.
  auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, coo_code);
  ASSERT_EQ(mg->GetCreatedFormats(), 1);
  auto hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, coo_code);
  auto img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
  ASSERT_TRUE(img != nullptr);

  // CSR and COO allowed: still only COO is created up front.
  mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, csr_code | coo_code);
  ASSERT_EQ(mg->GetCreatedFormats(), 1);
  hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, csr_code | coo_code);
  img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
  ASSERT_TRUE(img != nullptr);

  // CSC and COO allowed: still only COO is created up front.
  mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, csc_code | coo_code);
  ASSERT_EQ(mg->GetCreatedFormats(), 1);
  hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, csc_code | coo_code);
  img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
  ASSERT_TRUE(img != nullptr);
}

/*!
 * \brief Check GetCSCMatrix (the in-edge CSR) on graphs built from CSC, CSR and COO.
 *
 * Verifies the returned matrix shape. It also checks that converting through
 * GetGraphInFormat() leaves the original graph's created formats unchanged,
 * while calling GetCSCMatrix() on the graph itself adds CSC (bit 4) to them.
 */
template <typename IdType>
void _TestUnitGraph_GetInCSR(DLContext ctx) {
  const aten::CSRMatrix csr = CSR1<IdType>(ctx);
  const aten::COOMatrix coo = COO1<IdType>(ctx);

  // Graph created from CSC: the in-CSR keeps the input's shape.
  auto g = CreateFromCSC(2, csr);
  auto in_csr_matrix = g->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 4);

  // Graph created from CSR: the in-CSR has the transposed shape.
  g = CreateFromCSR(2, csr);
  auto g_ptr = g->GetGraphInFormat(csc_code);
  in_csr_matrix = g_ptr->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);
  // Converting via GetGraphInFormat must not add formats to `g` itself.
  ASSERT_EQ(g->GetCreatedFormats(), 2);

  // Asking `g` directly materializes CSC next to the existing CSR.
  in_csr_matrix = g->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 6);

  // Graph created from COO: same expectations, starting from COO only.
  g = CreateFromCOO(2, coo);
  g_ptr = g->GetGraphInFormat(csc_code);
  in_csr_matrix = g_ptr->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 1);

  in_csr_matrix = g->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 5);
}

/*!
 * \brief Check GetCSRMatrix (the out-edge CSR) on graphs built from CSC, CSR and COO.
 *
 * Verifies the returned matrix shape. It also checks that converting through
 * GetGraphInFormat() leaves the original graph's created formats unchanged,
 * while calling GetCSRMatrix() on the graph itself adds CSR (bit 2) to them.
 */
template <typename IdType>
void _TestUnitGraph_GetOutCSR(DLContext ctx) {
  const aten::CSRMatrix csr = CSR1<IdType>(ctx);
  const aten::COOMatrix coo = COO1<IdType>(ctx);

  // Graph created from CSC: the out-CSR has the transposed shape.
  auto g = CreateFromCSC(2, csr);
  auto g_ptr = g->GetGraphInFormat(csr_code);
  auto out_csr_matrix = g_ptr->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);
  // Converting via GetGraphInFormat must not add formats to `g` itself.
  ASSERT_EQ(g->GetCreatedFormats(), 4);

  // Asking `g` directly materializes CSR next to the existing CSC.
  out_csr_matrix = g->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 6);

  // Graph created from CSR: the out-CSR keeps the input's shape.
  g = CreateFromCSR(2, csr);
  out_csr_matrix = g->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 2);

  // Graph created from COO: the out-CSR keeps the COO's shape.
  g = CreateFromCOO(2, coo);
  g_ptr = g->GetGraphInFormat(csr_code);
  out_csr_matrix = g_ptr->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);
  ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 1);

  out_csr_matrix = g->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);
  ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 3);
}

/*!
 * \brief Check GetCOOMatrix on graphs built from CSC, CSR and COO.
 *
 * Verifies the returned matrix shape. It also checks that converting through
 * GetGraphInFormat() leaves the original graph's created formats unchanged,
 * while calling GetCOOMatrix() on the graph itself adds COO (bit 1) to them.
 */
template <typename IdType>
void _TestUnitGraph_GetCOO(DLContext ctx) {
  const aten::CSRMatrix csr = CSR1<IdType>(ctx);
  const aten::COOMatrix coo = COO1<IdType>(ctx);

  // Graph created from CSC: the COO has the transposed shape.
  auto g = CreateFromCSC(2, csr);
  auto g_ptr = g->GetGraphInFormat(coo_code);
  auto out_coo_matrix = g_ptr->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);
  // Converting via GetGraphInFormat must not add formats to `g` itself.
  ASSERT_EQ(g->GetCreatedFormats(), 4);

  // Asking `g` directly materializes COO next to the existing CSC.
  out_coo_matrix = g->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 5);

  // Graph created from CSR: the COO keeps the input's shape.
  g = CreateFromCSR(2, csr);
  g_ptr = g->GetGraphInFormat(coo_code);
  out_coo_matrix = g_ptr->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 2);

  out_coo_matrix = g->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 3);

  // Graph created from COO: nothing new needs to be materialized.
  g = CreateFromCOO(2, coo);
  out_coo_matrix = g->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_rows, coo.num_rows);
  ASSERT_EQ(out_coo_matrix.num_cols, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 1);
}

/*!
 * \brief Check UnitGraph::Reverse() and how formats are shared with the reversed graph.
 *
 * For graphs created from CSC, CSR and COO, the reversed graph's out-CSR must be
 * the original's in-CSR (and vice versa) backed by the same buffers. This is
 * checked by comparing `->data` pointers. The expected created-format counts
 * also show that materializing CSR/CSC on one graph makes it available on the
 * other. When the graph was not created from COO, its COO is built separately
 * and compared by value (ArrayEQ). When it was created from COO, the row/col
 * buffers are shared (swapped) between the two graphs.
 */
template <typename IdType>
void _TestUnitGraph_Reserve(DLContext ctx) {
  const aten::CSRMatrix csr = CSR1<IdType>(ctx);
  const aten::COOMatrix coo = COO1<IdType>(ctx);

  // ---- Graph created from CSC ----
  auto g = CreateFromCSC(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 4);
  auto r_g =
      std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();
  // The reversed graph exposes g's CSC as its CSR.
  ASSERT_EQ(r_g->GetCreatedFormats(), 2);

  aten::CSRMatrix g_in_csr = g->GetCSCMatrix(0);
  aten::CSRMatrix r_g_out_csr = r_g->GetCSRMatrix(0);
  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);

  // Creating CSR on g also makes CSC available on r_g.
  aten::CSRMatrix g_out_csr = g->GetCSRMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 6);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);

  aten::CSRMatrix r_g_in_csr = r_g->GetCSCMatrix(0);
  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);

  // COO is materialized independently on each graph.
  aten::COOMatrix g_coo = g->GetCOOMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 7);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);
  aten::COOMatrix r_g_coo = r_g->GetCOOMatrix(0);
  ASSERT_EQ(r_g->GetCreatedFormats(), 7);

  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));

  // ---- Graph created from CSR ----
  g = CreateFromCSR(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 2);
  r_g = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();
  // The reversed graph exposes g's CSR as its CSC.
  ASSERT_EQ(r_g->GetCreatedFormats(), 4);

  g_out_csr = g->GetCSRMatrix(0);
  r_g_in_csr = r_g->GetCSCMatrix(0);
  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);

  // Creating CSC on g also makes CSR available on r_g.
  g_in_csr = g->GetCSCMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 6);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);

  r_g_out_csr = r_g->GetCSRMatrix(0);
  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);

  // COO is materialized independently on each graph.
  g_coo = g->GetCOOMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 7);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);
  r_g_coo = r_g->GetCOOMatrix(0);
  ASSERT_EQ(r_g->GetCreatedFormats(), 7);

  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));

  // ---- Graph created from COO ----
  g = CreateFromCOO(2, coo);
  ASSERT_EQ(g->GetCreatedFormats(), 1);
  r_g = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();
  ASSERT_EQ(r_g->GetCreatedFormats(), 1);

  // The COO row/col buffers are shared, with the roles swapped.
  g_coo = g->GetCOOMatrix(0);
  r_g_coo = r_g->GetCOOMatrix(0);
  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
  ASSERT_TRUE(g_coo.row->data == r_g_coo.col->data);
  ASSERT_TRUE(g_coo.col->data == r_g_coo.row->data);

  // Creating CSC on g also makes CSR available on r_g.
  g_in_csr = g->GetCSCMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 5);
  ASSERT_EQ(r_g->GetCreatedFormats(), 3);

  r_g_out_csr = r_g->GetCSRMatrix(0);
  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);

  // Creating CSR on g also makes CSC available on r_g.
  g_out_csr = g->GetCSRMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 7);
  ASSERT_EQ(r_g->GetCreatedFormats(), 7);

  r_g_in_csr = r_g->GetCSCMatrix(0);
  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
}

// Created-format bookkeeping of the graph factories, for 32- and 64-bit ids.
TEST(UniGraphTest, TestUnitGraph_Create) {
  _TestUnitGraph<int32_t>(CPU);
  _TestUnitGraph<int64_t>(CPU);
#if defined(DGL_USE_CUDA)
  _TestUnitGraph<int32_t>(GPU);
  _TestUnitGraph<int64_t>(GPU);
#endif  // DGL_USE_CUDA
}

// In-edge CSR (GetCSCMatrix) retrieval, for 32- and 64-bit ids.
TEST(UniGraphTest, TestUnitGraph_GetInCSR) {
  _TestUnitGraph_GetInCSR<int32_t>(CPU);
  _TestUnitGraph_GetInCSR<int64_t>(CPU);
#if defined(DGL_USE_CUDA)
  _TestUnitGraph_GetInCSR<int32_t>(GPU);
  _TestUnitGraph_GetInCSR<int64_t>(GPU);
#endif  // DGL_USE_CUDA
}

// Out-edge CSR (GetCSRMatrix) retrieval, for 32- and 64-bit ids.
TEST(UniGraphTest, TestUnitGraph_GetOutCSR) {
  _TestUnitGraph_GetOutCSR<int32_t>(CPU);
  _TestUnitGraph_GetOutCSR<int64_t>(CPU);
#if defined(DGL_USE_CUDA)
  _TestUnitGraph_GetOutCSR<int32_t>(GPU);
  _TestUnitGraph_GetOutCSR<int64_t>(GPU);
#endif  // DGL_USE_CUDA
}

// COO retrieval, for 32- and 64-bit ids.
TEST(UniGraphTest, TestUnitGraph_GetCOO) {
  _TestUnitGraph_GetCOO<int32_t>(CPU);
  _TestUnitGraph_GetCOO<int64_t>(CPU);
#if defined(DGL_USE_CUDA)
  _TestUnitGraph_GetCOO<int32_t>(GPU);
  _TestUnitGraph_GetCOO<int64_t>(GPU);
#endif  // DGL_USE_CUDA
}

// Graph reversal and format sharing with the reversed graph, for 32- and 64-bit ids.
TEST(UniGraphTest, TestUnitGraph_Reserve) {
  _TestUnitGraph_Reserve<int32_t>(CPU);
  _TestUnitGraph_Reserve<int64_t>(CPU);
#if defined(DGL_USE_CUDA)
  _TestUnitGraph_Reserve<int32_t>(GPU);
  _TestUnitGraph_Reserve<int64_t>(GPU);
#endif  // DGL_USE_CUDA
}