/*! * Copyright (c) 2019 by Contributors * \file array/cpu/array_index_select.cc * \brief Array index select CPU implementation */ #include #include #include #include namespace dgl { using runtime::NDArray; using runtime::parallel_for; namespace aten { namespace impl { template std::pair ConcatSlices(NDArray array, IdArray lengths) { const int64_t rows = lengths->shape[0]; const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]); const int64_t stride = (array->ndim == 1 ? 0 : cols); const DType *array_data = static_cast(array->data); const IdType *length_data = static_cast(lengths->data); IdArray offsets = NewIdArray(rows, array->ctx, sizeof(IdType) * 8); IdType *offsets_data = static_cast(offsets->data); for (int64_t i = 0; i < rows; ++i) offsets_data[i] = (i == 0 ? 0 : length_data[i - 1] + offsets_data[i - 1]); const int64_t total_length = offsets_data[rows - 1] + length_data[rows - 1]; NDArray concat = NDArray::Empty({total_length}, array->dtype, array->ctx); DType *concat_data = static_cast(concat->data); parallel_for(0, rows, [=](size_t b, size_t e) { for (auto i = b; i < e; ++i) { for (int64_t j = 0; j < length_data[i]; ++j) concat_data[offsets_data[i] + j] = array_data[i * stride + j]; } }); return std::make_pair(concat, offsets); } template std::pair ConcatSlices(NDArray, IdArray); template std::pair ConcatSlices(NDArray, IdArray); template std::pair ConcatSlices(NDArray, IdArray); template std::pair ConcatSlices(NDArray, IdArray); template std::pair ConcatSlices(NDArray, IdArray); template std::pair ConcatSlices(NDArray, IdArray); template std::pair ConcatSlices(NDArray, IdArray); template std::pair ConcatSlices(NDArray, IdArray); template std::tuple Pack(NDArray array, DType pad_value) { CHECK_NDIM(array, 2, "array"); const DType *array_data = static_cast(array->data); const int64_t rows = array->shape[0]; const int64_t cols = array->shape[1]; IdArray length = NewIdArray(rows, array->ctx); int64_t *length_data = static_cast(length->data); parallel_for(0, rows, [=](size_t b, size_t e) { for (auto i = b; i < e; ++i) { int64_t j; for (j = 0; j < cols; ++j) { const DType val = array_data[i * cols + j]; if (val == pad_value) break; } length_data[i] = j; } }); auto ret = ConcatSlices(array, length); return std::make_tuple(ret.first, length, ret.second); } template std::tuple Pack(NDArray, int32_t); template std::tuple Pack(NDArray, int64_t); template std::tuple Pack(NDArray, float); template std::tuple Pack(NDArray, double); } // namespace impl } // namespace aten } // namespace dgl