/** * Copyright (c) 2020 by Contributors * @file array/cpu/array_repeat.cc * @brief Array repeat CPU implementation */ #include #include namespace dgl { using runtime::NDArray; namespace aten { namespace impl { template NDArray Repeat(NDArray array, IdArray repeats) { CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch"; const int64_t len = array->shape[0]; const DType *array_data = static_cast(array->data); const IdType *repeats_data = static_cast(repeats->data); IdType num_elements = 0; for (int64_t i = 0; i < len; ++i) num_elements += repeats_data[i]; NDArray result = NDArray::Empty({num_elements}, array->dtype, array->ctx); DType *result_data = static_cast(result->data); IdType curr = 0; for (int64_t i = 0; i < len; ++i) { std::fill( result_data + curr, result_data + curr + repeats_data[i], array_data[i]); curr += repeats_data[i]; } return result; } template NDArray Repeat(NDArray, IdArray); template NDArray Repeat(NDArray, IdArray); template NDArray Repeat(NDArray, IdArray); template NDArray Repeat(NDArray, IdArray); template NDArray Repeat(NDArray, IdArray); template NDArray Repeat(NDArray, IdArray); template NDArray Repeat(NDArray, IdArray); template NDArray Repeat(NDArray, IdArray); }; // namespace impl }; // namespace aten }; // namespace dgl