"git@developer.sourcefind.cn:change/sglang.git" did not exist on "f2388f6b9557e348732336a7a9afa14f167de33b"
rowwise_topk.cc 4.36 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_topk.cc
 * @brief Row-wise top-k selection of sparse-matrix entries by weight (CPU).
 */
#include <numeric>
#include <algorithm>
#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {

15
16
17
18
19
20
21
22
23
24
25
/*!
 * @brief Build the callback deciding how many entries to pick from a row.
 *
 * @param k Number of entries to keep per row; -1 means "keep every entry".
 * @return A NumPicksFn that yields len when k == -1, otherwise min(k, len).
 *
 * The comparison is carried out in int64_t and only the final result is
 * narrowed to IdxType. Narrowing k first (as an int32_t IdxType would) wraps
 * for k > INT32_MAX and could produce a negative or truncated pick count.
 * Negative values of k other than -1 are clamped to 0 so the returned count
 * is never negative.
 */
template <typename IdxType>
inline NumPicksFn<IdxType> GetTopkNumPicksFn(int64_t k) {
  NumPicksFn<IdxType> num_picks_fn = [k]
    (IdxType rowid, IdxType off, IdxType len,
     const IdxType* col, const IdxType* data) {
      if (k == -1) return len;  // pick the whole row
      const int64_t row_len = static_cast<int64_t>(len);
      const int64_t num_picks = std::min(std::max<int64_t>(k, 0), row_len);
      // Safe narrowing: 0 <= num_picks <= len, and len fits in IdxType.
      return static_cast<IdxType>(num_picks);
    };
  return num_picks_fn;
}

26
template <typename IdxType, typename DType>
27
inline PickFn<IdxType> GetTopkPickFn(NDArray weight, bool ascending) {
28
  const DType* wdata = static_cast<DType*>(weight->data);
29
30
  PickFn<IdxType> pick_fn = [ascending, wdata]
    (IdxType rowid, IdxType off, IdxType len, IdxType num_picks,
31
32
33
34
35
36
37
38
39
     const IdxType* col, const IdxType* data,
     IdxType* out_idx) {
      std::function<bool(IdxType, IdxType)> compare_fn;
      if (ascending) {
        if (data) {
          compare_fn = [wdata, data] (IdxType i, IdxType j) {
              return wdata[data[i]] < wdata[data[j]];
            };
        } else {
Jinjing Zhou's avatar
Jinjing Zhou committed
40
          compare_fn = [wdata] (IdxType i, IdxType j) {
41
42
43
44
45
46
47
48
49
              return wdata[i] < wdata[j];
            };
        }
      } else {
        if (data) {
          compare_fn = [wdata, data] (IdxType i, IdxType j) {
              return wdata[data[i]] > wdata[data[j]];
            };
        } else {
Jinjing Zhou's avatar
Jinjing Zhou committed
50
          compare_fn = [wdata] (IdxType i, IdxType j) {
51
52
53
54
55
56
57
58
              return wdata[i] > wdata[j];
            };
        }
      }

      std::vector<IdxType> idx(len);
      std::iota(idx.begin(), idx.end(), off);
      std::sort(idx.begin(), idx.end(), compare_fn);
59
      for (int64_t j = 0; j < num_picks; ++j) {
60
61
62
63
64
65
66
67
68
        out_idx[j] = idx[j];
      }
    };

  return pick_fn;
}

}  // namespace

69
template <DGLDeviceType XPU, typename IdxType, typename DType>
70
COOMatrix CSRRowWiseTopk(
71
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
72
73
74
  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
  return CSRRowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
75
76
}

77
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
78
    CSRMatrix, IdArray, int64_t, NDArray, bool);
79
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(
80
    CSRMatrix, IdArray, int64_t, NDArray, bool);
81
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(
82
    CSRMatrix, IdArray, int64_t, NDArray, bool);
83
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(
84
    CSRMatrix, IdArray, int64_t, NDArray, bool);
85
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(
86
    CSRMatrix, IdArray, int64_t, NDArray, bool);
87
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(
88
    CSRMatrix, IdArray, int64_t, NDArray, bool);
89
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(
90
    CSRMatrix, IdArray, int64_t, NDArray, bool);
91
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
92
    CSRMatrix, IdArray, int64_t, NDArray, bool);
93

94
template <DGLDeviceType XPU, typename IdxType, typename DType>
95
COOMatrix COORowWiseTopk(
96
    COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
97
98
99
  auto num_picks_fn = GetTopkNumPicksFn<IdxType>(k);
  auto pick_fn = GetTopkPickFn<IdxType, DType>(weight, ascending);
  return COORowWisePick(mat, rows, k, false, pick_fn, num_picks_fn);
100
101
}

102
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
103
    COOMatrix, IdArray, int64_t, NDArray, bool);
104
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(
105
    COOMatrix, IdArray, int64_t, NDArray, bool);
106
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int64_t>(
107
    COOMatrix, IdArray, int64_t, NDArray, bool);
108
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(
109
    COOMatrix, IdArray, int64_t, NDArray, bool);
110
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(
111
    COOMatrix, IdArray, int64_t, NDArray, bool);
112
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(
113
    COOMatrix, IdArray, int64_t, NDArray, bool);
114
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(
115
    COOMatrix, IdArray, int64_t, NDArray, bool);
116
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(
117
    COOMatrix, IdArray, int64_t, NDArray, bool);
118
119
120
121

}  // namespace impl
}  // namespace aten
}  // namespace dgl