/*! * Copyright (c) 2020 by Contributors * \file array/cpu/csr_sort.cc * \brief CSR sorting */ #include #include #include #include namespace dgl { namespace aten { namespace impl { ///////////////////////////// CSRIsSorted ///////////////////////////// template bool CSRIsSorted(CSRMatrix csr) { const IdType* indptr = csr.indptr.Ptr(); const IdType* indices = csr.indices.Ptr(); bool ret = true; for (int64_t row = 0; row < csr.num_rows; ++row) { if (!ret) continue; for (IdType i = indptr[row] + 1; i < indptr[row + 1]; ++i) { if (indices[i - 1] > indices[i]) { ret = false; break; } } } return ret; } template bool CSRIsSorted(CSRMatrix csr); template bool CSRIsSorted(CSRMatrix csr); ///////////////////////////// CSRSort ///////////////////////////// template void CSRSort_(CSRMatrix* csr) { typedef std::pair ShufflePair; const int64_t num_rows = csr->num_rows; const int64_t nnz = csr->indices->shape[0]; const IdType* indptr_data = static_cast(csr->indptr->data); IdType* indices_data = static_cast(csr->indices->data); if (!CSRHasData(*csr)) { csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx); } IdType* eid_data = static_cast(csr->data->data); #pragma omp parallel { std::vector reorder_vec; #pragma omp for for (int64_t row = 0; row < num_rows; row++) { const int64_t num_cols = indptr_data[row + 1] - indptr_data[row]; IdType *col = indices_data + indptr_data[row]; IdType *eid = eid_data + indptr_data[row]; reorder_vec.resize(num_cols); for (int64_t i = 0; i < num_cols; i++) { reorder_vec[i].first = col[i]; reorder_vec[i].second = eid[i]; } std::sort(reorder_vec.begin(), reorder_vec.end(), [](const ShufflePair &e1, const ShufflePair &e2) { return e1.first < e2.first; }); for (int64_t i = 0; i < num_cols; i++) { col[i] = reorder_vec[i].first; eid[i] = reorder_vec[i].second; } } } csr->sorted = true; } template void CSRSort_(CSRMatrix* csr); template void CSRSort_(CSRMatrix* csr); template std::pair CSRSortByTag( const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) { const auto indptr_data = static_cast(csr.indptr->data); const auto indices_data = static_cast(csr.indices->data); const auto eid_array = aten::CSRHasData(csr) ? csr.data : aten::Range(0, csr.indices->shape[0], csr.indptr->dtype.bits, csr.indptr->ctx); const auto eid_data = static_cast(csr.data->data); const auto tag_data = static_cast(tag_array->data); const int64_t num_rows = csr.num_rows; NDArray tag_pos = NDArray::Empty({csr.num_rows, num_tags + 1}, csr.indptr->dtype, csr.indptr->ctx); auto tag_pos_data = static_cast(tag_pos->data); std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0); aten::CSRMatrix output(csr.num_rows, csr.num_cols, csr.indptr.Clone(), csr.indices.Clone(), eid_array.Clone(), csr.sorted); auto out_indices_data = static_cast(output.indices->data); auto out_eid_data = static_cast(output.data->data); #pragma omp parallel for for (IdType src = 0 ; src < num_rows ; ++src) { const IdType start = indptr_data[src]; const IdType end = indptr_data[src + 1]; auto tag_pos_row = tag_pos_data + src * (num_tags + 1); std::vector pointer(num_tags, 0); for (IdType ptr = start ; ptr < end ; ++ptr) { const IdType dst = indices_data[ptr]; const TagType tag = tag_data[dst]; CHECK_LT(tag, num_tags); ++tag_pos_row[tag + 1]; } // count for (TagType tag = 1 ; tag <= num_tags; ++tag) { tag_pos_row[tag] += tag_pos_row[tag - 1]; } // cumulate for (IdType ptr = start ; ptr < end ; ++ptr) { const IdType dst = indices_data[ptr]; const IdType eid = eid_data[ptr]; const TagType tag = tag_data[dst]; const IdType offset = tag_pos_row[tag] + pointer[tag]; CHECK_LT(offset, tag_pos_row[tag + 1]); ++pointer[tag]; out_indices_data[start + offset] = dst; out_eid_data[start + offset] = eid; } } output.sorted = false; return std::make_pair(output, tag_pos); } template std::pair CSRSortByTag( const CSRMatrix &csr, const IdArray tag, int64_t num_tags); template std::pair CSRSortByTag( const CSRMatrix &csr, const IdArray tag, int64_t num_tags); template std::pair CSRSortByTag( const CSRMatrix &csr, const IdArray tag, int64_t num_tags); template std::pair CSRSortByTag( const CSRMatrix &csr, const IdArray tag, int64_t num_tags); } // namespace impl } // namespace aten } // namespace dgl