rowwise_topk.cc 4.49 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_topk.cc
 * @brief Row-wise top-k selection of sparse-matrix entries ranked by weight.
 */
#include <algorithm>
#include <numeric>
#include <vector>

sangwzh's avatar
sangwzh committed
10
#include "rowwise_pick.h"
11
12
13
14
15
16

namespace dgl {
namespace aten {
namespace impl {
namespace {

17
18
/**
 * @brief Build the callback that decides how many entries to pick per row.
 *
 * @param k Number of entries to keep per row; -1 means "keep the whole row".
 * @return A NumPicksFn returning min(k, len) for a row of length `len`
 *         (or `len` when k == -1).
 */
template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
  NumPicksFn<IdxType> num_picks_fn = [k](IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t row_len = static_cast<int64_t>(len);
    const int64_t max_num_picks = (k == -1) ? row_len : k;
    // Take the minimum in int64_t *before* narrowing: casting a large k to a
    // 32-bit IdxType first could wrap around and yield a bogus pick count.
    return static_cast<IdxType>(std::min(max_num_picks, row_len));
  };
  return num_picks_fn;
}

28
template <typename IdxType, typename DType>
29
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
30
  const DType* wdata = static_cast<DType*>(weight->data);
31
32
33
34
35
36
37
38
39
40
  PickFn<IdxType> pick_fn = [ascending, wdata](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    std::function<bool(IdxType, IdxType)> compare_fn;
    if (ascending) {
      if (data) {
        compare_fn = [wdata, data](IdxType i, IdxType j) {
          return wdata[data[i]] < wdata[data[j]];
        };
41
      } else {
42
43
44
        compare_fn = [wdata](IdxType i, IdxType j) {
          return wdata[i] < wdata[j];
        };
45
      }
46
47
48
49
50
51
52
53
54
    } else {
      if (data) {
        compare_fn = [wdata, data](IdxType i, IdxType j) {
          return wdata[data[i]] > wdata[data[j]];
        };
      } else {
        compare_fn = [wdata](IdxType i, IdxType j) {
          return wdata[i] > wdata[j];
        };
55
      }
56
57
58
59
60
61
62
63
64
    }

    std::vector<IdxType> idx(len);
    std::iota(idx.begin(), idx.end(), off);
    std::sort(idx.begin(), idx.end(), compare_fn);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] = idx[j];
    }
  };
65
66
67
68
69
70

  return pick_fn;
}

}  // namespace

71
template <DGLDeviceType XPU, typename IdxType, typename DType>
72
COOMatrix CSRRowWiseTopk(
73
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
74
75
76
  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
  return CSRRowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
77
78
}

79
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
80
    CSRMatrix, IdArray, int64_t, NDArray, bool);
81
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(
82
    CSRMatrix, IdArray, int64_t, NDArray, bool);
83
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(
84
    CSRMatrix, IdArray, int64_t, NDArray, bool);
85
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(
86
    CSRMatrix, IdArray, int64_t, NDArray, bool);
87
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(
88
    CSRMatrix, IdArray, int64_t, NDArray, bool);
89
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(
90
    CSRMatrix, IdArray, int64_t, NDArray, bool);
91
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(
92
    CSRMatrix, IdArray, int64_t, NDArray, bool);
93
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
94
    CSRMatrix, IdArray, int64_t, NDArray, bool);
95

96
template <DGLDeviceType XPU, typename IdxType, typename DType>
97
COOMatrix COORowWiseTopk(
98
    COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
99
100
101
  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
  return COORowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
102
103
}

104
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
105
    COOMatrix, IdArray, int64_t, NDArray, bool);
106
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(
107
    COOMatrix, IdArray, int64_t, NDArray, bool);
108
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int64_t>(
109
    COOMatrix, IdArray, int64_t, NDArray, bool);
110
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(
111
    COOMatrix, IdArray, int64_t, NDArray, bool);
112
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(
113
    COOMatrix, IdArray, int64_t, NDArray, bool);
114
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(
115
    COOMatrix, IdArray, int64_t, NDArray, bool);
116
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(
117
    COOMatrix, IdArray, int64_t, NDArray, bool);
118
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(
119
    COOMatrix, IdArray, int64_t, NDArray, bool);
120
121
122
123

}  // namespace impl
}  // namespace aten
}  // namespace dgl