/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cpu/coo_coalesce.cc
 * \brief COO coalescing: merge duplicate (row, col) entries of a COO matrix
 */

#include <dgl/array.h>

#include <utility>
#include <vector>

namespace dgl {

namespace aten {

namespace impl {

17
template <DGLDeviceType XPU, typename IdType>
18
19
20
21
22
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
  const int64_t nnz = coo.row->shape[0];
  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);

23
  if (!coo.row_sorted || !coo.col_sorted) coo = COOSort(coo, true);
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

  std::vector<IdType> new_row, new_col, count;
  IdType prev_row = -1, prev_col = -1;
  for (int64_t i = 0; i < nnz; ++i) {
    const IdType curr_row = coo_row_data[i];
    const IdType curr_col = coo_col_data[i];
    if (curr_row == prev_row && curr_col == prev_col) {
      ++count[count.size() - 1];
    } else {
      new_row.push_back(curr_row);
      new_col.push_back(curr_col);
      count.push_back(1);
      prev_row = curr_row;
      prev_col = curr_col;
    }
  }

  COOMatrix coo_result = COOMatrix{
42
43
44
45
46
47
      coo.num_rows,
      coo.num_cols,
      NDArray::FromVector(new_row),
      NDArray::FromVector(new_col),
      NullArray(),
      true};
48
49
50
  return std::make_pair(coo_result, NDArray::FromVector(count));
}

51
52
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int32_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int64_t>(COOMatrix);
53
54
55
56

};  // namespace impl
};  // namespace aten
};  // namespace dgl