/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/array_repeat.cc
 * @brief Array repeat CPU implementation
 */
#include <dgl/array.h>
#include <algorithm>

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

15
template <DGLDeviceType XPU, typename DType, typename IdType>
16
NDArray Repeat(NDArray array, IdArray repeats) {
17
18
  CHECK(array->shape[0] == repeats->shape[0])
      << "shape of array and repeats mismatch";
19
20
21
22
23
24

  const int64_t len = array->shape[0];
  const DType *array_data = static_cast<DType *>(array->data);
  const IdType *repeats_data = static_cast<IdType *>(repeats->data);

  IdType num_elements = 0;
25
  for (int64_t i = 0; i < len; ++i) num_elements += repeats_data[i];
26
27
28
29
30

  NDArray result = NDArray::Empty({num_elements}, array->dtype, array->ctx);
  DType *result_data = static_cast<DType *>(result->data);
  IdType curr = 0;
  for (int64_t i = 0; i < len; ++i) {
31
32
33
    std::fill(
        result_data + curr, result_data + curr + repeats_data[i],
        array_data[i]);
34
35
36
37
38
39
    curr += repeats_data[i];
  }

  return result;
}

40
41
42
43
44
45
46
47
template NDArray Repeat<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, double, int64_t>(NDArray, IdArray);
48
49
50
51

};  // namespace impl
};  // namespace aten
};  // namespace dgl