Unverified Commit aa11aaa4 authored by David Min's avatar David Min Committed by GitHub
Browse files

[Performance] Parallelize CSRSliceRows() (#3409)



* parallelize CSRSliceRows()

* use parallel_for for the second loop
Co-authored-by: nv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent ff94ee80
...@@ -366,27 +366,70 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -366,27 +366,70 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
const auto len = rows->shape[0]; const auto len = rows->shape[0];
const IdType* rows_data = static_cast<IdType*>(rows->data); const IdType* rows_data = static_cast<IdType*>(rows->data);
int64_t nnz = 0; int64_t nnz = 0;
for (int64_t i = 0; i < len; ++i) {
IdType vid = rows_data[i];
nnz += impl::CSRGetRowNNZ<XPU, IdType>(csr, vid);
}
CSRMatrix ret; CSRMatrix ret;
ret.num_rows = len; ret.num_rows = len;
ret.num_cols = csr.num_cols; ret.num_cols = csr.num_cols;
ret.indptr = NDArray::Empty({len + 1}, csr.indptr->dtype, csr.indices->ctx); ret.indptr = NDArray::Empty({len + 1}, csr.indptr->dtype, csr.indices->ctx);
IdType* ret_indptr_data = static_cast<IdType*>(ret.indptr->data);
ret_indptr_data[0] = 0;
std::vector<IdType> sums;
// Perform two-round parallel prefix sum using OpenMP
#pragma omp parallel
{
int64_t tid = omp_get_thread_num();
int64_t num_threads = omp_get_num_threads();
#pragma omp single
{
sums.resize(num_threads + 1);
sums[0] = 0;
}
int64_t sum = 0;
// First round of parallel prefix sum. All threads perform local prefix sums.
#pragma omp for schedule(static) nowait
for (int64_t i = 0; i < len; ++i) {
int64_t rid = rows_data[i];
sum += indptr_data[rid + 1] - indptr_data[rid];
ret_indptr_data[i + 1] = sum;
}
sums[tid + 1] = sum;
#pragma omp barrier
#pragma omp single
{
for (int64_t i = 1; i < num_threads; ++i)
sums[i] += sums[i - 1];
}
int64_t offset = sums[tid];
// Second round of parallel prefix sum. Update the local prefix sums.
#pragma omp for schedule(static)
for (int64_t i = 0; i < len; ++i)
ret_indptr_data[i + 1] += offset;
}
// After the prefix sum, the last element of ret_indptr_data holds the
// sum of all elements
nnz = ret_indptr_data[len];
ret.indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx); ret.indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
ret.data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx); ret.data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
ret.sorted = csr.sorted; ret.sorted = csr.sorted;
IdType* ret_indptr_data = static_cast<IdType*>(ret.indptr->data);
IdType* ret_indices_data = static_cast<IdType*>(ret.indices->data); IdType* ret_indices_data = static_cast<IdType*>(ret.indices->data);
IdType* ret_data = static_cast<IdType*>(ret.data->data); IdType* ret_data = static_cast<IdType*>(ret.data->data);
ret_indptr_data[0] = 0;
for (int64_t i = 0; i < len; ++i) { parallel_for(0, len, [=](int64_t b, int64_t e) {
for (auto i = b; i < e; ++i) {
const IdType rid = rows_data[i]; const IdType rid = rows_data[i];
// note: zero is allowed // note: zero is allowed
ret_indptr_data[i + 1] = ret_indptr_data[i] + indptr_data[rid + 1] - indptr_data[rid];
std::copy(indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1], std::copy(indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1],
ret_indices_data + ret_indptr_data[i]); ret_indices_data + ret_indptr_data[i]);
if (data) if (data)
...@@ -396,6 +439,7 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -396,6 +439,7 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
std::iota(ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 1], std::iota(ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 1],
indptr_data[rid]); indptr_data[rid]);
} }
});
return ret; return ret;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment