Unverified commit 00c27cb2, authored by Quan (Andy) Gan, committed by GitHub
Browse files

[Kernel] Parallel find edges (#4878)

* use runtime parallel_for

* grain size

* Update array_index_select.cc
parent 3132da28
@@ -4,6 +4,7 @@
  * @brief Array index select CPU implementation
  */
 #include <dgl/array.h>
+#include <dgl/runtime/parallel_for.h>
 namespace dgl {
 using runtime::NDArray;
@@ -22,10 +23,16 @@ NDArray IndexSelect(NDArray array, IdArray index) {
   const int64_t len = index->shape[0];
   NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
   DType* ret_data = static_cast<DType*>(ret->data);
-  for (int64_t i = 0; i < len; ++i) {
-    CHECK_LT(idx_data[i], arr_len) << "Index out of range.";
-    ret_data[i] = array_data[idx_data[i]];
-  }
+  runtime::parallel_for(
+      0,
+      len,
+      1000,  // Thread scheduling overhead is bigger with tiny grain size.
+      [idx_data, arr_len, ret_data, array_data] (size_t begin, size_t end) {
+        for (size_t i = begin; i < end; ++i) {
+          CHECK_LT(idx_data[i], arr_len) << "Index out of range.";
+          ret_data[i] = array_data[idx_data[i]];
+        }
+      });
   return ret;
 }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment