/*! * Copyright (c) 2020 by Contributors * @file array/cpu/rowwise_topk.cc * @brief rowwise topk */ #include #include #include "./rowwise_pick.h" namespace dgl { namespace aten { namespace impl { namespace { template inline NumPicksFn GetTopkNumPicksFn(int64_t k) { NumPicksFn num_picks_fn = [k] (IdxType rowid, IdxType off, IdxType len, const IdxType* col, const IdxType* data) { const int64_t max_num_picks = (k == -1) ? len : k; return std::min(static_cast(max_num_picks), len); }; return num_picks_fn; } template inline PickFn GetTopkPickFn(NDArray weight, bool ascending) { const DType* wdata = static_cast(weight->data); PickFn pick_fn = [ascending, wdata] (IdxType rowid, IdxType off, IdxType len, IdxType num_picks, const IdxType* col, const IdxType* data, IdxType* out_idx) { std::function compare_fn; if (ascending) { if (data) { compare_fn = [wdata, data] (IdxType i, IdxType j) { return wdata[data[i]] < wdata[data[j]]; }; } else { compare_fn = [wdata] (IdxType i, IdxType j) { return wdata[i] < wdata[j]; }; } } else { if (data) { compare_fn = [wdata, data] (IdxType i, IdxType j) { return wdata[data[i]] > wdata[data[j]]; }; } else { compare_fn = [wdata] (IdxType i, IdxType j) { return wdata[i] > wdata[j]; }; } } std::vector idx(len); std::iota(idx.begin(), idx.end(), off); std::sort(idx.begin(), idx.end(), compare_fn); for (int64_t j = 0; j < num_picks; ++j) { out_idx[j] = idx[j]; } }; return pick_fn; } } // namespace template COOMatrix CSRRowWiseTopk( CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) { auto num_picks_fn = GetTopkNumPicksFn(k); auto pick_fn = GetTopkPickFn(weight, ascending); return CSRRowWisePick(mat, rows, k, false, pick_fn, num_picks_fn); } template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix CSRRowWiseTopk( CSRMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) { auto num_picks_fn = GetTopkNumPicksFn(k); auto pick_fn = GetTopkPickFn(weight, ascending); return COORowWisePick(mat, rows, k, false, pick_fn, num_picks_fn); } template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); template COOMatrix COORowWiseTopk( COOMatrix, IdArray, int64_t, NDArray, bool); } // namespace impl } // namespace aten } // namespace dgl