csr_sort.cc 5.34 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/csr_sort.cc
 * \brief CSR sorting
 */
#include <dgl/array.h>
7
#include <dgl/runtime/parallel_for.h>
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <numeric>
#include <algorithm>
#include <vector>

namespace dgl {
namespace aten {
namespace impl {

///////////////////////////// CSRIsSorted /////////////////////////////
/*!
 * \brief Check whether the column indices of every row of a CSR matrix are
 *        in non-decreasing order.
 *
 * Repeated column indices within a row are allowed; only a strict decrease
 * between two consecutive entries of a row makes the matrix unsorted.
 *
 * \param csr The CSR matrix to inspect. Only `indptr` and `indices` are read.
 * \return true if every row is sorted (trivially true when there are no rows
 *         or all rows are empty), false as soon as one violation is found.
 */
template <DLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  for (int64_t row = 0; row < csr.num_rows; ++row) {
    // A single out-of-order pair decides the answer, so stop scanning
    // immediately instead of walking the remaining rows.
    for (IdType i = indptr[row] + 1; i < indptr[row + 1]; ++i) {
      if (indices[i - 1] > indices[i])
        return false;
    }
  }
  return true;
}

template bool CSRIsSorted<kDLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLCPU, int32_t>(CSRMatrix csr);

///////////////////////////// CSRSort /////////////////////////////

/*!
 * \brief Sort, in place, the column indices of every row of a CSR matrix.
 *
 * Each row's `indices` are reordered into non-decreasing order, and the
 * matching entries of `data` (edge ids) are permuted alongside them. If the
 * matrix has no `data` array, one holding 0..nnz-1 is created first, so the
 * applied permutation is recorded in `csr->data`.
 *
 * The sort is stable: entries with equal column index (e.g. parallel edges)
 * keep their original relative order, so the resulting edge-id order is
 * deterministic.
 *
 * Rows are split into chunks via runtime::parallel_for; each chunk touches
 * only its own rows' slices of `indices` and `data`.
 *
 * \param csr The matrix to sort. It is modified in place and marked sorted.
 */
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
  typedef std::pair<IdType, IdType> ShufflePair;
  const int64_t num_rows = csr->num_rows;
  const int64_t nnz = csr->indices->shape[0];
  const IdType* indptr_data = static_cast<IdType*>(csr->indptr->data);
  IdType* indices_data = static_cast<IdType*>(csr->indices->data);
  if (!CSRHasData(*csr)) {
    // Materialize identity edge ids so the permutation is tracked in data.
    csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx);
  }
  IdType* eid_data = static_cast<IdType*>(csr->data->data);

  runtime::parallel_for(0, num_rows, [=](size_t b, size_t e) {
    // One scratch buffer per chunk, reused across rows (resize keeps the
    // capacity) instead of allocating a fresh vector for every row.
    std::vector<ShufflePair> reorder_vec;
    for (auto row = b; row < e; ++row) {
      const IdType start = indptr_data[row];
      const int64_t num_cols = indptr_data[row + 1] - start;
      IdType* col = indices_data + start;
      IdType* eid = eid_data + start;

      reorder_vec.resize(num_cols);
      for (int64_t i = 0; i < num_cols; ++i) {
        reorder_vec[i].first = col[i];
        reorder_vec[i].second = eid[i];
      }
      std::stable_sort(reorder_vec.begin(), reorder_vec.end(),
                       [](const ShufflePair& e1, const ShufflePair& e2) {
                         return e1.first < e2.first;
                       });
      for (int64_t i = 0; i < num_cols; ++i) {
        col[i] = reorder_vec[i].first;
        eid[i] = reorder_vec[i].second;
      }
    }
  });

  csr->sorted = true;
}

template void CSRSort_<kDLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDLCPU, int32_t>(CSRMatrix* csr);

template <DLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
    const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
  const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
  const auto indices_data = static_cast<const IdType *>(csr.indices->data);
  const auto eid_array = aten::CSRHasData(csr) ? csr.data :
    aten::Range(0, csr.indices->shape[0], csr.indptr->dtype.bits, csr.indptr->ctx);
  const auto eid_data = static_cast<const IdType *>(csr.data->data);
  const auto tag_data = static_cast<const TagType *>(tag_array->data);
  const int64_t num_rows = csr.num_rows;

  NDArray tag_pos = NDArray::Empty({csr.num_rows, num_tags + 1},
      csr.indptr->dtype, csr.indptr->ctx);
  auto tag_pos_data = static_cast<IdType *>(tag_pos->data);
  std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0);

  aten::CSRMatrix output(csr.num_rows, csr.num_cols,
                         csr.indptr.Clone(), csr.indices.Clone(),
                         eid_array.Clone(), csr.sorted);

  auto out_indices_data = static_cast<IdType *>(output.indices->data);
  auto out_eid_data = static_cast<IdType *>(output.data->data);

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  runtime::parallel_for(0, num_rows, [&](size_t b, size_t e) {
    for (auto src = b; src < e; ++src) {
      const IdType start = indptr_data[src];
      const IdType end = indptr_data[src + 1];

      auto tag_pos_row = tag_pos_data + src * (num_tags + 1);
      std::vector<IdType> pointer(num_tags, 0);

      for (IdType ptr = start ; ptr < end ; ++ptr) {
        const IdType dst = indices_data[ptr];
        const TagType tag = tag_data[dst];
        CHECK_LT(tag, num_tags);
        ++tag_pos_row[tag + 1];
      }  // count

      for (TagType tag = 1 ; tag <= num_tags; ++tag) {
        tag_pos_row[tag] += tag_pos_row[tag - 1];
      }  // cumulate

      for (IdType ptr = start ; ptr < end ; ++ptr) {
        const IdType dst = indices_data[ptr];
        const IdType eid = eid_data[ptr];
        const TagType tag = tag_data[dst];
        const IdType offset = tag_pos_row[tag] + pointer[tag];
        CHECK_LT(offset, tag_pos_row[tag + 1]);
        ++pointer[tag];

        out_indices_data[start + offset] = dst;
        out_eid_data[start + offset] = eid;
      }
134
    }
135
  });
136
137
138
139
140
141
142
143
144
145
146
147
148
  output.sorted = false;
  return std::make_pair(output, tag_pos);
}

template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int64_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int32_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int64_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int32_t>(
    const CSRMatrix &csr, const IdArray tag, int64_t num_tags);

149
150
151
}  // namespace impl
}  // namespace aten
}  // namespace dgl