array_sort.cc 4.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/array_sort.cc
 * \brief Array sort CPU implementation
 */
#include <dgl/array.h>
#ifdef PARALLEL_ALGORITHMS
#include <parallel/algorithm>
#endif
#include <algorithm>
#include <iterator>

namespace {

/*!
 * \brief Proxy "reference" to the same position in two parallel arrays.
 *
 * PairIterator::operator* returns one of these by value. Assigning to a
 * PairRef writes through to both pointees. Converting it to
 * std::pair<V1, V2> reads both pointees. Together these let std::sort
 * permute the two arrays in lock-step.
 */
template <typename V1, typename V2>
struct PairRef {
  PairRef() = delete;
  PairRef(const PairRef& other) = default;
  PairRef(PairRef&& other) = default;
  PairRef(V1* const r, V2* const c) : row(r), col(c) {}

  // Copies the referenced *values*, not the pointers, which is the reference
  // semantics `*it = *other_it` must have. Because this operator is
  // user-declared, no implicit move-assignment exists, so moves from another
  // PairRef also land here.
  PairRef& operator=(const PairRef& other) {
    *row = *other.row;
    *col = *other.col;
    return *this;
  }

  // Stores a materialised value (e.g. a temporary taken out of the range)
  // back into both arrays.
  PairRef& operator=(const std::pair<V1, V2>& val) {
    *row = val.first;
    *col = val.second;
    return *this;
  }

  // Intentionally implicit: comparators written against std::pair<V1, V2>
  // are handed PairRef arguments and rely on this conversion.
  operator std::pair<V1, V2>() const { return std::make_pair(*row, *col); }

  // Exchanges the values referenced by *this and `other`. The method is const
  // because it mutates the pointees, not the proxy itself.
  void Swap(const PairRef& other) const {
    std::swap(*row, *other.row);
    std::swap(*col, *other.col);
  }

  V1* row;
  V2* col;
};

using std::swap;
template <typename V1, typename V2>
void swap(const PairRef<V1, V2>& r1, const PairRef<V1, V2>& r2) {
  r1.Swap(r2);
}

template <typename V1, typename V2>
51
52
53
54
struct PairIterator
    : public std::iterator<
          std::random_access_iterator_tag, std::pair<V1, V2>, std::ptrdiff_t,
          std::pair<V1*, V2*>, PairRef<V1, V2>> {
55
56
57
  PairIterator() = default;
  PairIterator(const PairIterator& other) = default;
  PairIterator(PairIterator&& other) = default;
58
  PairIterator(V1* r, V2* c) : row(r), col(c) {}
59
60
61
62
63

  PairIterator& operator=(const PairIterator& other) = default;
  PairIterator& operator=(PairIterator&& other) = default;
  ~PairIterator() = default;

64
  bool operator==(const PairIterator& other) const { return row == other.row; }
65

66
  bool operator!=(const PairIterator& other) const { return row != other.row; }
67

68
  bool operator<(const PairIterator& other) const { return row < other.row; }
69

70
  bool operator>(const PairIterator& other) const { return row > other.row; }
71

72
  bool operator<=(const PairIterator& other) const { return row <= other.row; }
73

74
  bool operator>=(const PairIterator& other) const { return row >= other.row; }
75
76
77
78
79
80
81
82
83
84
85
86
87

  PairIterator& operator+=(const std::ptrdiff_t& movement) {
    row += movement;
    col += movement;
    return *this;
  }

  PairIterator& operator-=(const std::ptrdiff_t& movement) {
    row -= movement;
    col -= movement;
    return *this;
  }

88
  PairIterator& operator++() { return operator+=(1); }
89

90
  PairIterator& operator--() { return operator-=(1); }
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119

  PairIterator operator++(int) {
    PairIterator ret(*this);
    operator++();
    return ret;
  }

  PairIterator operator--(int) {
    PairIterator ret(*this);
    operator--();
    return ret;
  }

  PairIterator operator+(const std::ptrdiff_t& movement) const {
    PairIterator ret(*this);
    ret += movement;
    return ret;
  }

  PairIterator operator-(const std::ptrdiff_t& movement) const {
    PairIterator ret(*this);
    ret -= movement;
    return ret;
  }

  std::ptrdiff_t operator-(const PairIterator& other) const {
    return row - other.row;
  }

120
121
  PairRef<V1, V2> operator*() const { return PairRef<V1, V2>(row, col); }
  PairRef<V1, V2> operator*() { return PairRef<V1, V2>(row, col); }
122

123
124
125
126
127
  // required for random access iterators in VS2019
  PairRef<V1, V2> operator[](size_t offset) const {
    return PairRef<V1, V2>(row + offset, col + offset);
  }

128
129
  V1* row;
  V2* col;
130
131
132
133
134
135
136
137
138
};

}  // namespace

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

139
template <DGLDeviceType XPU, typename IdType>
140
std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
141
142
143
144
145
146
147
148
149
150
151
152
153
  const int64_t nitem = array->shape[0];
  IdArray val = array.Clone();
  IdArray idx = aten::Range(0, nitem, 64, array->ctx);
  IdType* val_data = val.Ptr<IdType>();
  int64_t* idx_data = idx.Ptr<int64_t>();
  typedef std::pair<IdType, int64_t> Pair;
#ifdef PARALLEL_ALGORITHMS
  __gnu_parallel::sort(
#else
  std::sort(
#endif
      PairIterator<IdType, int64_t>(val_data, idx_data),
      PairIterator<IdType, int64_t>(val_data, idx_data) + nitem,
154
      [](const Pair& a, const Pair& b) {
155
156
157
158
159
        return std::get<0>(a) < std::get<0>(b);
      });
  return std::make_pair(val, idx);
}

// Explicit instantiations of the CPU Sort for int32_t and int64_t ids.
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(
    IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(
    IdArray, int num_bits);

}  // namespace impl
}  // namespace aten
}  // namespace dgl