/*!
 *  Copyright (c) 2019 by Contributors
 * \file test_unit_graph.cc
 * \brief Tests for UnitGraph: sparse-format creation and conversion,
 *        Reverse(), and CopyTo() across devices.
 */
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <gtest/gtest.h>

#include <memory>
#include <vector>

#include "../../src/graph/unit_graph.h"
#include "./../src/graph/heterograph.h"
#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

/*!
 * \brief Build a 4x3 CSR fixture whose index arrays live on \p ctx.
 *
 * \param ctx Device on which indptr/indices are allocated. The original
 *        implementation ignored this and used the file-external constant
 *        `CTX`, so callers passing a GPU context silently got CPU arrays.
 * \return CSR matrix with 4 rows, 3 columns, 6 non-zeros and no data array.
 */
template <typename IdType>
aten::CSRMatrix CSR1(DLContext ctx) {
  /*
   * G = [[0, 0, 1],
   *      [1, 0, 1],
   *      [0, 1, 0],
   *      [1, 0, 1]]
   */
  IdArray g_indptr = aten::VecToIdArray(
      std::vector<IdType>({0, 1, 3, 4, 6}), sizeof(IdType) * 8, ctx);
  IdArray g_indices = aten::VecToIdArray(
      std::vector<IdType>({2, 0, 2, 1, 0, 2}), sizeof(IdType) * 8, ctx);

  // Trailing `false` is presumably the "sorted" flag — confirm against
  // aten::CSRMatrix's constructor.
  return aten::CSRMatrix(4, 3, g_indptr, g_indices, aten::NullArray(), false);
}

template aten::CSRMatrix CSR1<int32_t>(DLContext ctx);
template aten::CSRMatrix CSR1<int64_t>(DLContext ctx);

/*!
 * \brief Build a 2x3 COO fixture whose row/col arrays live on \p ctx.
 *
 * \param ctx Device on which row/col are allocated. The original
 *        implementation ignored this and used the file-external constant
 *        `CTX`, so callers passing a GPU context silently got CPU arrays.
 * \return COO matrix with 2 rows, 3 columns, 3 non-zeros and no data array.
 */
template <typename IdType>
aten::COOMatrix COO1(DLContext ctx) {
  /*
   * G = [[1, 1, 0],
   *      [0, 1, 0]]
   */
  IdArray g_row = aten::VecToIdArray(
      std::vector<IdType>({0, 0, 1}), sizeof(IdType) * 8, ctx);
  IdArray g_col = aten::VecToIdArray(
      std::vector<IdType>({0, 1, 1}), sizeof(IdType) * 8, ctx);

  // The two trailing `true`s are presumably row_sorted / col_sorted (the
  // entries above are in row-major order) — confirm against aten::COOMatrix.
  return aten::COOMatrix(2, 3, g_row, g_col, aten::NullArray(), true, true);
}

template aten::COOMatrix COO1<int32_t>(DLContext ctx);
template aten::COOMatrix COO1<int64_t>(DLContext ctx);

/*!
 * \brief Check which sparse formats a freshly created graph reports.
 *
 * The expected GetCreatedFormats() values pinned below are 4 after
 * CreateFromCSC, 2 after CreateFromCSR and 1 after CreateFromCOO.
 * NOTE(review): these look like the CSC_CODE / CSR_CODE / COO_CODE bit
 * values — confirm against unit_graph.h.
 */
template <typename IdType>
void _TestUnitGraph(DLContext ctx) {
  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
  const aten::COOMatrix &coo = COO1<IdType>(ctx);

  // Each creator records exactly the format it was handed.
  auto g = CreateFromCSC(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 4);

  g = CreateFromCSR(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 2);

  g = CreateFromCOO(2, coo);
  ASSERT_EQ(g->GetCreatedFormats(), 1);

  // UnitGraph::CreateFromCOO records only COO even when CSR or CSC is also
  // allowed; the single-vertex-type variant must convert to ImmutableGraph.
  auto src = aten::VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = aten::VecToIdArray<int64_t>({1, 6, 2, 6});

  auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, COO_CODE);
  ASSERT_EQ(mg->GetCreatedFormats(), 1);
  auto hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, COO_CODE);
  auto img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
  ASSERT_TRUE(img != nullptr);

  mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, CSR_CODE | COO_CODE);
  ASSERT_EQ(mg->GetCreatedFormats(), 1);
  hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, CSR_CODE | COO_CODE);
  img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
  ASSERT_TRUE(img != nullptr);

  mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, CSC_CODE | COO_CODE);
  ASSERT_EQ(mg->GetCreatedFormats(), 1);
  hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, CSC_CODE | COO_CODE);
  img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
  ASSERT_TRUE(img != nullptr);
}

/*!
 * \brief Check reading the in-CSR (CSC) matrix from graphs of each origin.
 *
 * For every creation format this verifies the CSC shape (transposed relative
 * to the source when the source was CSR/COO), and pins GetCreatedFormats():
 * GetGraphInFormat() leaves the original graph's counter unchanged, while
 * GetCSCMatrix() on the graph itself adds CSC to it.
 */
template <typename IdType>
void _TestUnitGraph_GetInCSR(DLContext ctx) {
  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
  const aten::COOMatrix &coo = COO1<IdType>(ctx);

  // Graph created from CSC: reading CSC back needs no conversion.
  auto g = CreateFromCSC(2, csr);
  auto in_csr_matrix = g->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 4);

  // Graph created from CSR.
  g = CreateFromCSR(2, csr);
  auto g_ptr = g->GetGraphInFormat(CSC_CODE);
  in_csr_matrix = g_ptr->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 2);  // g itself untouched

  in_csr_matrix = g->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 6);  // CSR + newly recorded CSC

  // Graph created from COO.
  g = CreateFromCOO(2, coo);
  g_ptr = g->GetGraphInFormat(CSC_CODE);
  in_csr_matrix = g_ptr->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 1);  // g itself untouched

  in_csr_matrix = g->GetCSCMatrix(0);
  ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);
  ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 5);  // COO + newly recorded CSC
}

/*!
 * \brief Check reading the out-CSR matrix from graphs of each origin.
 *
 * Verifies the CSR shape (transposed relative to the source when the source
 * was CSC) and pins GetCreatedFormats(): GetGraphInFormat() leaves the
 * original graph's counter unchanged, while GetCSRMatrix() on the graph
 * itself adds CSR to it.
 */
template <typename IdType>
void _TestUnitGraph_GetOutCSR(DLContext ctx) {
  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
  const aten::COOMatrix &coo = COO1<IdType>(ctx);

  // Graph created from CSC.
  auto g = CreateFromCSC(2, csr);
  auto g_ptr = g->GetGraphInFormat(CSR_CODE);
  auto out_csr_matrix = g_ptr->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 4);  // g itself untouched

  out_csr_matrix = g->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 6);  // CSC + newly recorded CSR

  // Graph created from CSR: reading CSR back needs no conversion.
  g = CreateFromCSR(2, csr);
  out_csr_matrix = g->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(out_csr_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 2);

  // Graph created from COO.
  g = CreateFromCOO(2, coo);
  g_ptr = g->GetGraphInFormat(CSR_CODE);
  out_csr_matrix = g_ptr->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);
  ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 1);  // g itself untouched

  out_csr_matrix = g->GetCSRMatrix(0);
  ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);
  ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 3);  // COO + newly recorded CSR
}

/*!
 * \brief Check reading the COO matrix from graphs of each origin.
 *
 * Verifies the COO shape (transposed relative to the source when the source
 * was CSC) and pins GetCreatedFormats(): GetGraphInFormat() leaves the
 * original graph's counter unchanged, while GetCOOMatrix() on the graph
 * itself adds COO to it.
 */
template <typename IdType>
void _TestUnitGraph_GetCOO(DLContext ctx) {
  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
  const aten::COOMatrix &coo = COO1<IdType>(ctx);

  // Graph created from CSC.
  auto g = CreateFromCSC(2, csr);
  auto g_ptr = g->GetGraphInFormat(COO_CODE);
  auto out_coo_matrix = g_ptr->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 4);  // g itself untouched

  out_coo_matrix = g->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 5);  // CSC + newly recorded COO

  // Graph created from CSR.
  g = CreateFromCSR(2, csr);
  g_ptr = g->GetGraphInFormat(COO_CODE);
  out_coo_matrix = g_ptr->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 2);  // g itself untouched

  out_coo_matrix = g->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);
  ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 3);  // CSR + newly recorded COO

  // Graph created from COO: reading COO back needs no conversion.
  g = CreateFromCOO(2, coo);
  out_coo_matrix = g->GetCOOMatrix(0);
  ASSERT_EQ(out_coo_matrix.num_rows, coo.num_rows);
  ASSERT_EQ(out_coo_matrix.num_cols, coo.num_cols);
  ASSERT_EQ(g->GetCreatedFormats(), 1);
}

/*!
 * \brief Check UnitGraph::Reverse() against the original graph.
 *
 * NOTE(review): despite the "Reserve" name this exercises Reverse(); the
 * name is kept because the TEST below refers to it.
 *
 * For graphs created from CSC, CSR and COO it checks that:
 *  - the reversed graph starts out with the swapped format (CSC <-> CSR,
 *    COO stays COO);
 *  - g's in-CSR and r_g's out-CSR (and vice versa) point at the same
 *    buffers;
 *  - r_g's COO is g's COO with rows and columns swapped.
 * The pinned counters also show that a CSR/CSC materialized on one graph
 * becomes visible on the other, while COO is recorded per graph when either
 * side started from CSR/CSC.
 */
template <typename IdType>
void _TestUnitGraph_Reserve(DLContext ctx) {
  const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
  const aten::COOMatrix &coo = COO1<IdType>(ctx);

  // ---- Graph created from CSC ----
  auto g = CreateFromCSC(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 4);
  auto r_g =
      std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();
  ASSERT_EQ(r_g->GetCreatedFormats(), 2);

  aten::CSRMatrix g_in_csr = g->GetCSCMatrix(0);
  aten::CSRMatrix r_g_out_csr = r_g->GetCSRMatrix(0);
  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);

  // Materializing g's CSR also makes r_g's CSC available.
  aten::CSRMatrix g_out_csr = g->GetCSRMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 6);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);
  aten::CSRMatrix r_g_in_csr = r_g->GetCSCMatrix(0);
  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);

  // COO is recorded separately on each side.
  aten::COOMatrix g_coo = g->GetCOOMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 7);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);
  aten::COOMatrix r_g_coo = r_g->GetCOOMatrix(0);
  ASSERT_EQ(r_g->GetCreatedFormats(), 7);
  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));

  // ---- Graph created from CSR ----
  g = CreateFromCSR(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 2);
  r_g = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();
  ASSERT_EQ(r_g->GetCreatedFormats(), 4);

  g_out_csr = g->GetCSRMatrix(0);
  r_g_in_csr = r_g->GetCSCMatrix(0);
  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);

  // Materializing g's CSC also makes r_g's CSR available.
  g_in_csr = g->GetCSCMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 6);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);
  r_g_out_csr = r_g->GetCSRMatrix(0);
  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);

  g_coo = g->GetCOOMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 7);
  ASSERT_EQ(r_g->GetCreatedFormats(), 6);
  r_g_coo = r_g->GetCOOMatrix(0);
  ASSERT_EQ(r_g->GetCreatedFormats(), 7);
  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));
  ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));

  // ---- Graph created from COO ----
  g = CreateFromCOO(2, coo);
  ASSERT_EQ(g->GetCreatedFormats(), 1);
  r_g = std::dynamic_pointer_cast<UnitGraph>(g->GetRelationGraph(0))->Reverse();
  ASSERT_EQ(r_g->GetCreatedFormats(), 1);

  // Here the reversed COO shares buffers with the original (row <-> col).
  g_coo = g->GetCOOMatrix(0);
  r_g_coo = r_g->GetCOOMatrix(0);
  ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
  ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
  ASSERT_TRUE(g_coo.row->data == r_g_coo.col->data);
  ASSERT_TRUE(g_coo.col->data == r_g_coo.row->data);

  g_in_csr = g->GetCSCMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 5);
  ASSERT_EQ(r_g->GetCreatedFormats(), 3);
  r_g_out_csr = r_g->GetCSRMatrix(0);
  ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
  ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);

  g_out_csr = g->GetCSRMatrix(0);
  ASSERT_EQ(g->GetCreatedFormats(), 7);
  ASSERT_EQ(r_g->GetCreatedFormats(), 7);
  r_g_in_csr = r_g->GetCSCMatrix(0);
  ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
  ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
}

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
template <typename IdType>
void _TestUnitGraph_CopyTo(const DLContext &src_ctx,
                           const DGLContext &dst_ctx) {
  const aten::CSRMatrix &csr = CSR1<IdType>(src_ctx);
  const aten::COOMatrix &coo = COO1<IdType>(src_ctx);

  auto device = dgl::runtime::DeviceAPI::Get(dst_ctx);
  auto stream = device->CreateStream(dst_ctx);

  auto g = dgl::UnitGraph::CreateFromCSC(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 4);
  auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
  device->StreamSync(dst_ctx, stream);
  ASSERT_EQ(cg->GetCreatedFormats(), 4);

  g = dgl::UnitGraph::CreateFromCSR(2, csr);
  ASSERT_EQ(g->GetCreatedFormats(), 2);
  cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
  device->StreamSync(dst_ctx, stream);
  ASSERT_EQ(cg->GetCreatedFormats(), 2);

  g = dgl::UnitGraph::CreateFromCOO(2, coo);
  ASSERT_EQ(g->GetCreatedFormats(), 1);
  cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
  device->StreamSync(dst_ctx, stream);
  ASSERT_EQ(cg->GetCreatedFormats(), 1);
}

TEST(UniGraphTest, TestUnitGraph_CopyTo) {
  // Host-to-host copies are always exercised, 32-bit ids before 64-bit.
  _TestUnitGraph_CopyTo<int32_t>(CPU, CPU);
  _TestUnitGraph_CopyTo<int64_t>(CPU, CPU);
#ifdef DGL_USE_CUDA
  // Upload, on-device and download directions, walked in this order for
  // each index width (all int32 first, then all int64).
  const DLContext from[] = {CPU, GPU, GPU};
  const DLContext to[] = {GPU, GPU, CPU};
  for (int k = 0; k < 3; ++k) {
    _TestUnitGraph_CopyTo<int32_t>(from[k], to[k]);
  }
  for (int k = 0; k < 3; ++k) {
    _TestUnitGraph_CopyTo<int64_t>(from[k], to[k]);
  }
#endif
}

343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
TEST(UniGraphTest, TestUnitGraph_Create) {
  _TestUnitGraph<int32_t>(CPU);
  _TestUnitGraph<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestUnitGraph<int32_t>(GPU);
  _TestUnitGraph<int64_t>(GPU);
#endif
}

TEST(UniGraphTest, TestUnitGraph_GetInCSR) {
  // Host always; GPU too in CUDA builds. Per device: int32 then int64.
  const DLContext contexts[] = {
    CPU,
#ifdef DGL_USE_CUDA
    GPU,
#endif
  };
  for (const DLContext &dev : contexts) {
    _TestUnitGraph_GetInCSR<int32_t>(dev);
    _TestUnitGraph_GetInCSR<int64_t>(dev);
  }
}

TEST(UniGraphTest, TestUnitGraph_GetOutCSR) {
  // Host always; GPU too in CUDA builds. Per device: int32 then int64.
  const DLContext contexts[] = {
    CPU,
#ifdef DGL_USE_CUDA
    GPU,
#endif
  };
  for (const DLContext &dev : contexts) {
    _TestUnitGraph_GetOutCSR<int32_t>(dev);
    _TestUnitGraph_GetOutCSR<int64_t>(dev);
  }
}

TEST(UniGraphTest, TestUnitGraph_GetCOO) {
  // Host always; GPU too in CUDA builds. Per device: int32 then int64.
  const DLContext contexts[] = {
    CPU,
#ifdef DGL_USE_CUDA
    GPU,
#endif
  };
  for (const DLContext &dev : contexts) {
    _TestUnitGraph_GetCOO<int32_t>(dev);
    _TestUnitGraph_GetCOO<int64_t>(dev);
  }
}

TEST(UniGraphTest, TestUnitGraph_Reserve) {
  // Reverse() checks. Host always; GPU too in CUDA builds.
  // Per device: int32 then int64.
  const DLContext contexts[] = {
    CPU,
#ifdef DGL_USE_CUDA
    GPU,
#endif
  };
  for (const DLContext &dev : contexts) {
    _TestUnitGraph_Reserve<int32_t>(dev);
    _TestUnitGraph_Reserve<int64_t>(dev);
  }
}