"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "98457580c061a182e2f99cd9fb80f73db386cc59"
array_repeat.cc 1.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cpu/array_repeat.cc
 * \brief Array repeat CPU implementation
 */
#include <dgl/array.h>
#include <algorithm>

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

14
template <DGLDeviceType XPU, typename DType, typename IdType>
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
NDArray Repeat(NDArray array, IdArray repeats) {
  CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch";

  const int64_t len = array->shape[0];
  const DType *array_data = static_cast<DType *>(array->data);
  const IdType *repeats_data = static_cast<IdType *>(repeats->data);

  IdType num_elements = 0;
  for (int64_t i = 0; i < len; ++i)
    num_elements += repeats_data[i];

  NDArray result = NDArray::Empty({num_elements}, array->dtype, array->ctx);
  DType *result_data = static_cast<DType *>(result->data);
  IdType curr = 0;
  for (int64_t i = 0; i < len; ++i) {
    std::fill(result_data + curr, result_data + curr + repeats_data[i], array_data[i]);
    curr += repeats_data[i];
  }

  return result;
}

// Explicit CPU instantiations of Repeat, one per (DType, IdType) pair listed
// below, grouped by element type DType; each is paired with both 32- and
// 64-bit repeat-count types.

// DType = int32_t
template NDArray Repeat<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);

// DType = int64_t
template NDArray Repeat<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);

// DType = float
template NDArray Repeat<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, float, int64_t>(NDArray, IdArray);

// DType = double
template NDArray Repeat<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, double, int64_t>(NDArray, IdArray);

};  // namespace impl
};  // namespace aten
};  // namespace dgl