coo_sort.cc 5.9 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/coo_sort.cc
 * \brief COO sorting
 */
#include <dgl/array.h>
7
8
9
#ifdef PARALLEL_ALGORITHMS
#include <parallel/algorithm>
#endif
10
#include <algorithm>
11
#include <iterator>
12
#include <numeric>
13
#include <tuple>
14
#include <vector>
15
16
17
18
19
20
21
22

namespace {

/*!
 * \brief Proxy reference to one COO entry: the values stored at the same
 *        position in three parallel arrays (row, col, data).
 *
 * A TupleRef holds only pointers. Assigning to it writes through to the
 * pointed-to elements, and converting it to std::tuple reads a snapshot of
 * them. Together these let generic algorithms such as std::sort permute the
 * three arrays in lock-step.
 */
template <typename IdType>
struct TupleRef {
  TupleRef() = delete;
  TupleRef(const TupleRef& other) = default;
  TupleRef(TupleRef&& other) = default;

  /*! \brief Refer to the entry located at *r, *c and *d. */
  TupleRef(IdType* const r, IdType* const c, IdType* const d)
      : row(r), col(c), data(d) {}

  // Copies the referenced *values* of `other`, not its pointers. No move
  // assignment is declared, so assignment from an rvalue TupleRef (e.g. an
  // algorithm writing `*a = std::move(*b)`) also resolves to this overload.
  TupleRef& operator=(const TupleRef& other) {
    *row = *other.row;
    *col = *other.col;
    *data = *other.data;
    return *this;
  }

  // Stores a previously saved (row, col, data) triple back into the arrays.
  TupleRef& operator=(const std::tuple<IdType, IdType, IdType>& val) {
    *row = std::get<0>(val);
    *col = std::get<1>(val);
    *data = std::get<2>(val);
    return *this;
  }

  // Reads the referenced entry by value; this tuple type is the iterator's
  // value_type.
  operator std::tuple<IdType, IdType, IdType>() const {
    return std::make_tuple(*row, *col, *data);
  }

  // Exchanges the referenced entries of *this and `other`, element-wise.
  void Swap(const TupleRef& other) const {
    std::swap(*row, *other.row);
    std::swap(*col, *other.col);
    std::swap(*data, *other.data);
  }

  IdType *row, *col, *data;
};

// Make std::swap visible in this namespace alongside the TupleRef overload.
using std::swap;

// ADL-visible swap for TupleRef proxies. Dereferencing a CooIterator yields a
// prvalue TupleRef, which cannot bind to std::swap's `T&` parameters, so a
// sorting algorithm that swaps `*a` and `*b` needs this overload.
template <typename IdType>
void swap(const TupleRef<IdType>& r1, const TupleRef<IdType>& r2) {
  r1.Swap(r2);
}

/*!
 * \brief Random-access iterator over a COO matrix stored as three parallel
 *        arrays (row, col, data).
 *
 * The three pointers always move together. Dereferencing yields a TupleRef
 * proxy, so writes through the iterator update all three arrays at once.
 */
template <typename IdType>
struct CooIterator {
  // Spelled out directly instead of inheriting from std::iterator, which is
  // deprecated since C++17. std::iterator_traits picks these up.
  using iterator_category = std::random_access_iterator_tag;
  using value_type = std::tuple<IdType, IdType, IdType>;
  using difference_type = std::ptrdiff_t;
  using pointer = std::tuple<IdType*, IdType*, IdType*>;
  using reference = TupleRef<IdType>;

  CooIterator() = default;
  CooIterator(const CooIterator& other) = default;
  CooIterator(CooIterator&& other) = default;
  CooIterator(IdType* r, IdType* c, IdType* d) : row(r), col(c), data(d) {}

  CooIterator& operator=(const CooIterator& other) = default;
  CooIterator& operator=(CooIterator&& other) = default;
  ~CooIterator() = default;

  // row, col and data always advance together, so comparing `row` alone is
  // enough to order two iterators over the same arrays.
  bool operator==(const CooIterator& other) const { return row == other.row; }
  bool operator!=(const CooIterator& other) const { return row != other.row; }
  bool operator<(const CooIterator& other) const { return row < other.row; }
  bool operator>(const CooIterator& other) const { return row > other.row; }
  bool operator<=(const CooIterator& other) const { return row <= other.row; }
  bool operator>=(const CooIterator& other) const { return row >= other.row; }

  CooIterator& operator+=(difference_type movement) {
    row += movement;
    col += movement;
    data += movement;
    return *this;
  }

  CooIterator& operator-=(difference_type movement) {
    row -= movement;
    col -= movement;
    data -= movement;
    return *this;
  }

  CooIterator& operator++() { return operator+=(1); }
  CooIterator& operator--() { return operator-=(1); }

  CooIterator operator++(int) {
    CooIterator ret(*this);
    operator++();
    return ret;
  }

  CooIterator operator--(int) {
    CooIterator ret(*this);
    operator--();
    return ret;
  }

  CooIterator operator+(difference_type movement) const {
    CooIterator ret(*this);
    ret += movement;
    return ret;
  }

  // `n + it`, required of random access iterators alongside `it + n`.
  friend CooIterator operator+(difference_type movement,
                               const CooIterator& it) {
    return it + movement;
  }

  CooIterator operator-(difference_type movement) const {
    CooIterator ret(*this);
    ret -= movement;
    return ret;
  }

  difference_type operator-(const CooIterator& other) const {
    return row - other.row;
  }

  reference operator*() const { return reference(row, col, data); }

  // Required for random access iterators (VS2019 checks for it). The offset
  // is the signed difference_type, so negative offsets such as it[-1] work.
  reference operator[](difference_type offset) const {
    return reference(row + offset, col + offset, data + offset);
  }

  // Null-initialized so a default-constructed iterator has a defined state.
  IdType* row = nullptr;
  IdType* col = nullptr;
  IdType* data = nullptr;
};

}  // namespace

namespace dgl {
namespace aten {
namespace impl {

150
151
///////////////////////////// COOSort_ /////////////////////////////

152
template <DGLDeviceType XPU, typename IdType>
153
154
155
156
157
158
159
160
161
162
163
void COOSort_(COOMatrix* coo, bool sort_column) {
  const int64_t nnz = coo->row->shape[0];
  IdType* coo_row = coo->row.Ptr<IdType>();
  IdType* coo_col = coo->col.Ptr<IdType>();
  if (!COOHasData(*coo))
    coo->data = aten::Range(0, nnz, coo->row->dtype.bits, coo->row->ctx);
  IdType* coo_data = coo->data.Ptr<IdType>();

  typedef std::tuple<IdType, IdType, IdType> Tuple;

  // Arg sort
164
  if (sort_column) {
165
166
167
#ifdef PARALLEL_ALGORITHMS
    __gnu_parallel::sort(
#else
168
    std::sort(
169
#endif
170
171
172
        CooIterator<IdType>(coo_row, coo_col, coo_data),
        CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz,
        [](const Tuple& a, const Tuple& b) {
173
174
175
          return (std::get<0>(a) != std::get<0>(b))
                     ? (std::get<0>(a) < std::get<0>(b))
                     : (std::get<1>(a) < std::get<1>(b));
176
177
        });
  } else {
178
179
180
#ifdef PARALLEL_ALGORITHMS
    __gnu_parallel::sort(
#else
181
    std::sort(
182
#endif
183
184
185
186
        CooIterator<IdType>(coo_row, coo_col, coo_data),
        CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz,
        [](const Tuple& a, const Tuple& b) {
          return std::get<0>(a) < std::get<0>(b);
187
188
189
        });
  }

190
191
192
  coo->row_sorted = true;
  coo->col_sorted = sort_column;
}
193

194
195
template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);
196

197
198
///////////////////////////// COOIsSorted /////////////////////////////

199
template <DGLDeviceType XPU, typename IdType>
200
201
202
203
204
205
206
207
208
209
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  const int64_t nnz = coo.row->shape[0];
  IdType* row = coo.row.Ptr<IdType>();
  IdType* col = coo.col.Ptr<IdType>();
  bool row_sorted = true;
  bool col_sorted = true;
  for (int64_t i = 1; row_sorted && i < nnz; ++i) {
    row_sorted = (row[i - 1] <= row[i]);
    col_sorted = col_sorted && (row[i - 1] < row[i] || col[i - 1] <= col[i]);
  }
210
  if (!row_sorted) col_sorted = false;
211
  return {row_sorted, col_sorted};
212
213
}

214
215
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int64_t>(COOMatrix coo);
216
217
218
219

}  // namespace impl
}  // namespace aten
}  // namespace dgl