/*! * Copyright (c) 2020 by Contributors * \file array/cpu/rowwise_sampling.cc * \brief rowwise sampling */ #include #include #include "./rowwise_pick.h" namespace dgl { namespace aten { namespace impl { namespace { // Equivalent to numpy expression: array[idx[off:off + len]] template inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) { const FloatType* array_data = static_cast(array->data); FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx); FloatType* ret_data = static_cast(ret->data); for (int64_t j = 0; j < len; ++j) { if (idx_data) ret_data[j] = array_data[idx_data[off + j]]; else ret_data[j] = array_data[off + j]; } return ret; } template inline PickFn GetSamplingPickFn( int64_t num_samples, FloatArray prob, bool replace) { PickFn pick_fn = [prob, num_samples, replace] (IdxType rowid, IdxType off, IdxType len, const IdxType* col, const IdxType* data, IdxType* out_idx) { // TODO(minjie): If efficiency is a problem, consider avoid creating // explicit NDArrays by directly manipulating buffers. FloatArray prob_selected = DoubleSlice(prob, data, off, len); IdArray sampled = RandomEngine::ThreadLocal()->Choice( num_samples, prob_selected, replace); const IdxType* sampled_data = static_cast(sampled->data); for (int64_t j = 0; j < num_samples; ++j) { out_idx[j] = off + sampled_data[j]; } }; return pick_fn; } template inline PickFn GetSamplingUniformPickFn( int64_t num_samples, bool replace) { PickFn pick_fn = [num_samples, replace] (IdxType rowid, IdxType off, IdxType len, const IdxType* col, const IdxType* data, IdxType* out_idx) { // TODO(minjie): If efficiency is a problem, consider avoid creating // explicit NDArrays by directly manipulating buffers. IdArray sampled = RandomEngine::ThreadLocal()->UniformChoice( num_samples, len, replace); const IdxType* sampled_data = static_cast(sampled->data); for (int64_t j = 0; j < num_samples; ++j) { out_idx[j] = off + sampled_data[j]; } }; return pick_fn; } } // namespace /////////////////////////////// CSR /////////////////////////////// template COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) { auto pick_fn = GetSamplingPickFn(num_samples, prob, replace); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) { auto pick_fn = GetSamplingUniformPickFn(num_samples, replace); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix, IdArray, int64_t, bool); template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix, IdArray, int64_t, bool); /////////////////////////////// COO /////////////////////////////// template COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) { auto pick_fn = GetSamplingPickFn(num_samples, prob, replace); return COORowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) { auto pick_fn = GetSamplingUniformPickFn(num_samples, replace); return COORowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix COORowWiseSamplingUniform( COOMatrix, IdArray, int64_t, bool); template COOMatrix COORowWiseSamplingUniform( COOMatrix, IdArray, int64_t, bool); } // namespace impl } // namespace aten } // namespace dgl