/*! * Copyright (c) 2020 by Contributors * @file array/cpu/array_sort.cc * @brief Array sort CPU implementation */ #include #ifdef PARALLEL_ALGORITHMS #include #endif #include #include namespace { template struct PairRef { PairRef() = delete; PairRef(const PairRef& other) = default; PairRef(PairRef&& other) = default; PairRef(V1* const r, V2* const c) : row(r), col(c) {} PairRef& operator=(const PairRef& other) { *row = *other.row; *col = *other.col; return *this; } PairRef& operator=(const std::pair& val) { *row = std::get<0>(val); *col = std::get<1>(val); return *this; } operator std::pair() const { return std::make_pair(*row, *col); } void Swap(const PairRef& other) const { std::swap(*row, *other.row); std::swap(*col, *other.col); } V1* row; V2* col; }; using std::swap; template void swap(const PairRef& r1, const PairRef& r2) { r1.Swap(r2); } template struct PairIterator : public std::iterator< std::random_access_iterator_tag, std::pair, std::ptrdiff_t, std::pair, PairRef> { PairIterator() = default; PairIterator(const PairIterator& other) = default; PairIterator(PairIterator&& other) = default; PairIterator(V1* r, V2* c) : row(r), col(c) {} PairIterator& operator=(const PairIterator& other) = default; PairIterator& operator=(PairIterator&& other) = default; ~PairIterator() = default; bool operator==(const PairIterator& other) const { return row == other.row; } bool operator!=(const PairIterator& other) const { return row != other.row; } bool operator<(const PairIterator& other) const { return row < other.row; } bool operator>(const PairIterator& other) const { return row > other.row; } bool operator<=(const PairIterator& other) const { return row <= other.row; } bool operator>=(const PairIterator& other) const { return row >= other.row; } PairIterator& operator+=(const std::ptrdiff_t& movement) { row += movement; col += movement; return *this; } PairIterator& operator-=(const std::ptrdiff_t& movement) { row -= movement; col -= movement; return *this; } PairIterator& operator++() { return operator+=(1); } PairIterator& operator--() { return operator-=(1); } PairIterator operator++(int) { PairIterator ret(*this); operator++(); return ret; } PairIterator operator--(int) { PairIterator ret(*this); operator--(); return ret; } PairIterator operator+(const std::ptrdiff_t& movement) const { PairIterator ret(*this); ret += movement; return ret; } PairIterator operator-(const std::ptrdiff_t& movement) const { PairIterator ret(*this); ret -= movement; return ret; } std::ptrdiff_t operator-(const PairIterator& other) const { return row - other.row; } PairRef operator*() const { return PairRef(row, col); } PairRef operator*() { return PairRef(row, col); } // required for random access iterators in VS2019 PairRef operator[](size_t offset) const { return PairRef(row + offset, col + offset); } V1* row; V2* col; }; } // namespace namespace dgl { using runtime::NDArray; namespace aten { namespace impl { template std::pair Sort(IdArray array, int /* num_bits */) { const int64_t nitem = array->shape[0]; IdArray val = array.Clone(); IdArray idx = aten::Range(0, nitem, 64, array->ctx); IdType* val_data = val.Ptr(); int64_t* idx_data = idx.Ptr(); typedef std::pair Pair; #ifdef PARALLEL_ALGORITHMS __gnu_parallel::sort( #else std::sort( #endif PairIterator(val_data, idx_data), PairIterator(val_data, idx_data) + nitem, [](const Pair& a, const Pair& b) { return std::get<0>(a) < std::get<0>(b); }); return std::make_pair(val, idx); } template std::pair Sort( IdArray, int num_bits); template std::pair Sort( IdArray, int num_bits); } // namespace impl } // namespace aten } // namespace dgl