coo_sort.cc 6.06 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/coo_sort.cc
 * \brief COO sorting
 */
#include <dgl/array.h>
7
8
9
#ifdef PARALLEL_ALGORITHMS
#include <parallel/algorithm>
#endif
10
11
12
#include <numeric>
#include <algorithm>
#include <vector>
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#include <iterator>
#include <tuple>

namespace {

/*!
 * \brief Proxy "reference" to one COO entry stored across three parallel arrays.
 *
 * The object only holds pointers to the row, column and data element of a
 * single entry.  Assigning to it writes through those pointers, and converting
 * it to a std::tuple reads the three pointed-to values.  This is what
 * CooIterator hands out when dereferenced.
 */
template <typename IdType>
struct TupleRef {
  TupleRef() = delete;
  TupleRef(const TupleRef& other) = default;
  TupleRef(TupleRef&& other) = default;
  TupleRef(IdType *const r, IdType *const c, IdType *const d)
    : row(r), col(c), data(d) {}

  /*! \brief Copy the *values* referenced by \a other into the referenced entry. */
  TupleRef& operator=(const TupleRef& other) {
    *row = *other.row;
    *col = *other.col;
    *data = *other.data;
    return *this;
  }

  /*! \brief Store a (row, col, data) value tuple into the referenced entry. */
  TupleRef& operator=(const std::tuple<IdType, IdType, IdType>& val) {
    *row = std::get<0>(val);
    *col = std::get<1>(val);
    *data = std::get<2>(val);
    return *this;
  }

  /*! \brief Read the referenced entry as a (row, col, data) value tuple. */
  operator std::tuple<IdType, IdType, IdType>() const {
    return std::make_tuple(*row, *col, *data);
  }

  /*!
   * \brief Exchange the values of the two referenced entries.
   *
   * Declared const because only the pointed-to values change, never the
   * pointers held by the proxy itself; this lets it be called on temporaries.
   */
  void Swap(const TupleRef& other) const {
    std::swap(*row, *other.row);
    std::swap(*col, *other.col);
    std::swap(*data, *other.data);
  }

  IdType *row, *col, *data;
};

using std::swap;

/*!
 * \brief swap overload for proxy references.
 *
 * Takes const references so that it also accepts the temporaries produced by
 * dereferencing a CooIterator.
 */
template <typename IdType>
void swap(const TupleRef<IdType>& r1, const TupleRef<IdType>& r2) {
  r1.Swap(r2);
}

/*!
 * \brief Random access iterator that walks three parallel arrays in lock step.
 *
 * Dereferencing yields a TupleRef proxy, so algorithms that move or swap
 * elements through the iterator update all three arrays together.
 * Position comparisons and distances are computed from the row pointer only;
 * the three pointers are always advanced by the same amount.
 */
template <typename IdType>
struct CooIterator {
  // Member typedefs are spelled out explicitly instead of inheriting from
  // std::iterator, which is deprecated since C++17.
  using iterator_category = std::random_access_iterator_tag;
  using value_type = std::tuple<IdType, IdType, IdType>;
  using difference_type = std::ptrdiff_t;
  using pointer = std::tuple<IdType*, IdType*, IdType*>;
  using reference = TupleRef<IdType>;

  CooIterator() = default;
  CooIterator(const CooIterator& other) = default;
  CooIterator(CooIterator&& other) = default;
  CooIterator(IdType *r, IdType *c, IdType *d): row(r), col(c), data(d) {}

  CooIterator& operator=(const CooIterator& other) = default;
  CooIterator& operator=(CooIterator&& other) = default;
  ~CooIterator() = default;

  bool operator==(const CooIterator& other) const {
    return row == other.row;
  }

  bool operator!=(const CooIterator& other) const {
    return row != other.row;
  }

  bool operator<(const CooIterator& other) const {
    return row < other.row;
  }

  bool operator>(const CooIterator& other) const {
    return row > other.row;
  }

  bool operator<=(const CooIterator& other) const {
    return row <= other.row;
  }

  bool operator>=(const CooIterator& other) const {
    return row >= other.row;
  }

  CooIterator& operator+=(const std::ptrdiff_t& movement) {
    row += movement;
    col += movement;
    data += movement;
    return *this;
  }

  CooIterator& operator-=(const std::ptrdiff_t& movement) {
    row -= movement;
    col -= movement;
    data -= movement;
    return *this;
  }

  CooIterator& operator++() {
    return operator+=(1);
  }

  CooIterator& operator--() {
    return operator-=(1);
  }

  CooIterator operator++(int) {
    CooIterator ret(*this);
    operator++();
    return ret;
  }

  CooIterator operator--(int) {
    CooIterator ret(*this);
    operator--();
    return ret;
  }

  CooIterator operator+(const std::ptrdiff_t& movement) const {
    CooIterator ret(*this);
    ret += movement;
    return ret;
  }

  // Random access iterators must also support `n + it`.
  friend CooIterator operator+(difference_type movement, const CooIterator& it) {
    return it + movement;
  }

  CooIterator operator-(const std::ptrdiff_t& movement) const {
    CooIterator ret(*this);
    ret -= movement;
    return ret;
  }

  /*! \brief Distance between two iterators, measured on the row array. */
  std::ptrdiff_t operator-(const CooIterator& other) const {
    return row - other.row;
  }

  /*! \brief Proxy reference to the entry at the current position. */
  TupleRef<IdType> operator*() const {
    return TupleRef<IdType>(row, col, data);
  }

  // required for random access iterators in VS2019
  // The offset is the signed difference_type so that negative offsets are
  // well defined rather than relying on unsigned wraparound.
  TupleRef<IdType> operator[](difference_type offset) const {
    return TupleRef<IdType>(row + offset, col + offset, data + offset);
  }

  // Null by default so that default-constructed iterators compare equal
  // instead of reading indeterminate pointers.
  IdType *row = nullptr, *col = nullptr, *data = nullptr;
};

}  // namespace
163
164
165
166
167

namespace dgl {
namespace aten {
namespace impl {

168
169
///////////////////////////// COOSort_ /////////////////////////////

170
template <DGLDeviceType XPU, typename IdType>
171
172
173
174
175
176
177
178
179
180
181
void COOSort_(COOMatrix* coo, bool sort_column) {
  const int64_t nnz = coo->row->shape[0];
  IdType* coo_row = coo->row.Ptr<IdType>();
  IdType* coo_col = coo->col.Ptr<IdType>();
  if (!COOHasData(*coo))
    coo->data = aten::Range(0, nnz, coo->row->dtype.bits, coo->row->ctx);
  IdType* coo_data = coo->data.Ptr<IdType>();

  typedef std::tuple<IdType, IdType, IdType> Tuple;

  // Arg sort
182
  if (sort_column) {
183
184
185
#ifdef PARALLEL_ALGORITHMS
    __gnu_parallel::sort(
#else
186
    std::sort(
187
#endif
188
189
190
191
192
        CooIterator<IdType>(coo_row, coo_col, coo_data),
        CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz,
        [](const Tuple& a, const Tuple& b) {
          return (std::get<0>(a) != std::get<0>(b)) ?
              (std::get<0>(a) < std::get<0>(b)) : (std::get<1>(a) < std::get<1>(b));
193
194
        });
  } else {
195
196
197
#ifdef PARALLEL_ALGORITHMS
    __gnu_parallel::sort(
#else
198
    std::sort(
199
#endif
200
201
202
203
        CooIterator<IdType>(coo_row, coo_col, coo_data),
        CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz,
        [](const Tuple& a, const Tuple& b) {
          return std::get<0>(a) < std::get<0>(b);
204
205
206
        });
  }

207
208
209
  coo->row_sorted = true;
  coo->col_sorted = sort_column;
}
210

211
212
// Explicit instantiations of COOSort_ for the kDGLCPU device with 32-bit and
// 64-bit index types.
template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);
213

214

215
216
///////////////////////////// COOIsSorted /////////////////////////////

217
template <DGLDeviceType XPU, typename IdType>
218
219
220
221
222
223
224
225
226
227
228
229
230
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  const int64_t nnz = coo.row->shape[0];
  IdType* row = coo.row.Ptr<IdType>();
  IdType* col = coo.col.Ptr<IdType>();
  bool row_sorted = true;
  bool col_sorted = true;
  for (int64_t i = 1; row_sorted && i < nnz; ++i) {
    row_sorted = (row[i - 1] <= row[i]);
    col_sorted = col_sorted && (row[i - 1] < row[i] || col[i - 1] <= col[i]);
  }
  if (!row_sorted)
    col_sorted = false;
  return {row_sorted, col_sorted};
231
232
}

233
234
// Explicit instantiations of COOIsSorted for the kDGLCPU device with 32-bit
// and 64-bit index types.
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int64_t>(COOMatrix coo);
235
236
237
238

}  // namespace impl
}  // namespace aten
}  // namespace dgl