/*! * Copyright (c) 2019 by Contributors * @file array/cpu/array_scatter.cc * @brief Array scatter CPU implementation */ #include #include namespace dgl { using runtime::NDArray; namespace aten { namespace impl { template NDArray Scatter(NDArray array, IdArray indices) { NDArray result = NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx); const DType *array_data = static_cast(array->data); const IdType *indices_data = static_cast(indices->data); DType *result_data = static_cast(result->data); for (int64_t i = 0; i < indices->shape[0]; ++i) result_data[indices_data[i]] = array_data[i]; return result; } template NDArray Scatter(NDArray, IdArray); template NDArray Scatter(NDArray, IdArray); template NDArray Scatter(NDArray, IdArray); template NDArray Scatter(NDArray, IdArray); template NDArray Scatter(NDArray, IdArray); template NDArray Scatter(NDArray, IdArray); template NDArray Scatter(NDArray, IdArray); template NDArray Scatter(NDArray, IdArray); template void Scatter_(IdArray index, NDArray value, NDArray out) { const int64_t len = index->shape[0]; const IdType *idx = index.Ptr(); const DType *val = value.Ptr(); DType *outd = out.Ptr(); runtime::parallel_for(0, len, [&](size_t b, size_t e) { for (auto i = b; i < e; ++i) { outd[idx[i]] = val[i]; } }); } template void Scatter_(IdArray, NDArray, NDArray); template void Scatter_(IdArray, NDArray, NDArray); template void Scatter_(IdArray, NDArray, NDArray); template void Scatter_(IdArray, NDArray, NDArray); template void Scatter_(IdArray, NDArray, NDArray); template void Scatter_(IdArray, NDArray, NDArray); template void Scatter_(IdArray, NDArray, NDArray); template void Scatter_(IdArray, NDArray, NDArray); }; // namespace impl }; // namespace aten }; // namespace dgl