/*!
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/coo_linegraph.cc
 * @brief CPU implementation of line-graph construction for COO matrices.
 */

#include <dgl/array.h>

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

18
template <DGLDeviceType XPU, typename IdType>
19
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {
20
21
22
  const int64_t nnz = coo.row->shape[0];
  IdType* coo_row = coo.row.Ptr<IdType>();
  IdType* coo_col = coo.col.Ptr<IdType>();
23
24
25
  IdArray data = COOHasData(coo)
                     ? coo.data
                     : Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);
26
27
28
29
30
31
32
33
34
  IdType* data_data = data.Ptr<IdType>();
  std::vector<IdType> new_row;
  std::vector<IdType> new_col;

  for (int64_t i = 0; i < nnz; ++i) {
    IdType u = coo_row[i];
    IdType v = coo_col[i];
    for (int64_t j = 0; j < nnz; ++j) {
      // no self-loop
35
      if (i == j) continue;
36
37
38
39
40
41
42
43
44
45

      // succ_u == v
      // if not backtracking succ_u != u
      if (v == coo_row[j] && (backtracking || u != coo_col[j])) {
        new_row.push_back(data_data[i]);
        new_col.push_back(data_data[j]);
      }
    }
  }

46
47
48
  COOMatrix res = COOMatrix(
      nnz, nnz, NDArray::FromVector(new_row), NDArray::FromVector(new_col),
      NullArray(), false, false);
49
50
51
  return res;
}

// Explicit instantiations of the CPU kernel for 32- and 64-bit id types.
template COOMatrix COOLineGraph<kDGLCPU, int32_t>(
    const COOMatrix& coo, bool backtracking);
template COOMatrix COOLineGraph<kDGLCPU, int64_t>(
    const COOMatrix& coo, bool backtracking);

}  // namespace impl
}  // namespace aten
}  // namespace dgl