/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_topk.cc
 * @brief Row-wise top-k edge selection by weight for CSR and COO matrices
 *        (CPU implementation).
 */
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {

16
17
template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
18
19
20
21
22
23
  NumPicksFn<IdxType> num_picks_fn = [k](IdxType rowid, IdxType off,
                                         IdxType len, const IdxType* col,
                                         const IdxType* data) {
    const int64_t max_num_picks = (k == -1) ? len : k;
    return std::min(static_cast<IdxType>(max_num_picks), len);
  };
24
25
26
  return num_picks_fn;
}

27
template <typename IdxType, typename DType>
28
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
29
  const DType* wdata = static_cast<DType*>(weight->data);
30
31
32
33
34
35
36
37
38
39
  PickFn<IdxType> pick_fn = [ascending, wdata](
                                IdxType rowid, IdxType off, IdxType len,
                                IdxType num_picks, const IdxType* col,
                                const IdxType* data, IdxType* out_idx) {
    std::function<bool(IdxType, IdxType)> compare_fn;
    if (ascending) {
      if (data) {
        compare_fn = [wdata, data](IdxType i, IdxType j) {
          return wdata[data[i]] < wdata[data[j]];
        };
40
      } else {
41
42
43
        compare_fn = [wdata](IdxType i, IdxType j) {
          return wdata[i] < wdata[j];
        };
44
      }
45
46
47
48
49
50
51
52
53
    } else {
      if (data) {
        compare_fn = [wdata, data](IdxType i, IdxType j) {
          return wdata[data[i]] > wdata[data[j]];
        };
      } else {
        compare_fn = [wdata](IdxType i, IdxType j) {
          return wdata[i] > wdata[j];
        };
54
      }
55
56
57
58
59
60
61
62
63
    }

    std::vector<IdxType> idx(len);
    std::iota(idx.begin(), idx.end(), off);
    std::sort(idx.begin(), idx.end(), compare_fn);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] = idx[j];
    }
  };
64
65
66
67
68
69

  return pick_fn;
}

}  // namespace

70
template <DGLDeviceType XPU, typename IdxType, typename DType>
71
COOMatrix CSRRowWiseTopk(
72
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
73
74
75
  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
  return CSRRowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
76
77
}

78
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
79
    CSRMatrix, IdArray, int64_t, NDArray, bool);
80
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(
81
    CSRMatrix, IdArray, int64_t, NDArray, bool);
82
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(
83
    CSRMatrix, IdArray, int64_t, NDArray, bool);
84
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(
85
    CSRMatrix, IdArray, int64_t, NDArray, bool);
86
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(
87
    CSRMatrix, IdArray, int64_t, NDArray, bool);
88
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(
89
    CSRMatrix, IdArray, int64_t, NDArray, bool);
90
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(
91
    CSRMatrix, IdArray, int64_t, NDArray, bool);
92
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
93
    CSRMatrix, IdArray, int64_t, NDArray, bool);
94

95
template <DGLDeviceType XPU, typename IdxType, typename DType>
96
COOMatrix COORowWiseTopk(
97
    COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
98
99
100
  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
  return COORowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
101
102
}

103
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
104
    COOMatrix, IdArray, int64_t, NDArray, bool);
105
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(
106
    COOMatrix, IdArray, int64_t, NDArray, bool);
107
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int64_t>(
108
    COOMatrix, IdArray, int64_t, NDArray, bool);
109
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(
110
    COOMatrix, IdArray, int64_t, NDArray, bool);
111
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(
112
    COOMatrix, IdArray, int64_t, NDArray, bool);
113
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(
114
    COOMatrix, IdArray, int64_t, NDArray, bool);
115
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(
116
    COOMatrix, IdArray, int64_t, NDArray, bool);
117
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(
118
    COOMatrix, IdArray, int64_t, NDArray, bool);
119
120
121
122

}  // namespace impl
}  // namespace aten
}  // namespace dgl