/*! * Copyright (c) 2020 by Contributors * \file array/cpu/rowwise_sampling.cc * \brief rowwise sampling */ #include #include #include "./rowwise_pick.h" namespace dgl { namespace aten { namespace impl { namespace { // Equivalent to numpy expression: array[idx[off:off + len]] template inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) { const FloatType* array_data = static_cast(array->data); FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx); FloatType* ret_data = static_cast(ret->data); for (int64_t j = 0; j < len; ++j) { if (idx_data) ret_data[j] = array_data[idx_data[off + j]]; else ret_data[j] = array_data[off + j]; } return ret; } template inline PickFn GetSamplingPickFn( int64_t num_samples, FloatArray prob, bool replace) { PickFn pick_fn = [prob, num_samples, replace] (IdxType rowid, IdxType off, IdxType len, const IdxType* col, const IdxType* data, IdxType* out_idx) { FloatArray prob_selected = DoubleSlice(prob, data, off, len); RandomEngine::ThreadLocal()->Choice( num_samples, prob_selected, out_idx, replace); for (int64_t j = 0; j < num_samples; ++j) { out_idx[j] += off; } }; return pick_fn; } template inline RangePickFn GetSamplingRangePickFn( int64_t num_samples, FloatArray prob, bool replace) { RangePickFn pick_fn = [prob, num_samples, replace] (IdxType off, IdxType et_offset, IdxType et_len, const std::vector &et_idx, const IdxType* data, IdxType* out_idx) { const FloatType* p_data = static_cast(prob->data); FloatArray probs = FloatArray::Empty({et_len}, prob->dtype, prob->ctx); FloatType* probs_data = static_cast(probs->data); for (int64_t j = 0; j < et_len; ++j) { if (data) probs_data[j] = p_data[data[off+et_idx[et_offset+j]]]; else probs_data[j] = p_data[off+et_idx[et_offset+j]]; } RandomEngine::ThreadLocal()->Choice( num_samples, probs, out_idx, replace); }; return pick_fn; } template inline PickFn GetSamplingUniformPickFn( int64_t num_samples, bool replace) { PickFn pick_fn = [num_samples, replace] (IdxType rowid, IdxType off, IdxType len, const IdxType* col, const IdxType* data, IdxType* out_idx) { RandomEngine::ThreadLocal()->UniformChoice( num_samples, len, out_idx, replace); for (int64_t j = 0; j < num_samples; ++j) { out_idx[j] += off; } }; return pick_fn; } template inline RangePickFn GetSamplingUniformRangePickFn( int64_t num_samples, bool replace) { RangePickFn pick_fn = [num_samples, replace] (IdxType off, IdxType et_offset, IdxType et_len, const std::vector &et_idx, const IdxType* data, IdxType* out_idx) { RandomEngine::ThreadLocal()->UniformChoice( num_samples, et_len, out_idx, replace); }; return pick_fn; } template inline PickFn GetSamplingBiasedPickFn( int64_t num_samples, IdArray split, FloatArray bias, bool replace) { PickFn pick_fn = [num_samples, split, bias, replace] (IdxType rowid, IdxType off, IdxType len, const IdxType* col, const IdxType* data, IdxType* out_idx) { const IdxType *tag_offset = static_cast(split->data) + rowid * split->shape[1]; RandomEngine::ThreadLocal()->BiasedChoice( num_samples, tag_offset, bias, out_idx, replace); for (int64_t j = 0; j < num_samples; ++j) { out_idx[j] += off; } }; return pick_fn; } } // namespace /////////////////////////////// CSR /////////////////////////////// template COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) { CHECK(prob.defined()); auto pick_fn = GetSamplingPickFn(num_samples, prob, replace); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) { CHECK(prob.defined()); auto pick_fn = GetSamplingRangePickFn(num_samples, prob, replace); return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); } template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) { auto pick_fn = GetSamplingUniformPickFn(num_samples, replace); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix, IdArray, int64_t, bool); template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix, IdArray, int64_t, bool); template COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, bool replace, bool etype_sorted) { auto pick_fn = GetSamplingUniformRangePickFn(num_samples, replace); return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); } template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix, IdArray, IdArray, int64_t, bool, bool); template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix, IdArray, IdArray, int64_t, bool, bool); template COOMatrix CSRRowWiseSamplingBiased( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset, FloatArray bias, bool replace ) { auto pick_fn = GetSamplingBiasedPickFn( num_samples, tag_offset, bias, replace); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix CSRRowWiseSamplingBiased( CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); template COOMatrix CSRRowWiseSamplingBiased( CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); template COOMatrix CSRRowWiseSamplingBiased( CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); template COOMatrix CSRRowWiseSamplingBiased( CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); /////////////////////////////// COO /////////////////////////////// template COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) { CHECK(prob.defined()); auto pick_fn = GetSamplingPickFn(num_samples, prob, replace); return COORowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWiseSampling( COOMatrix, IdArray, int64_t, FloatArray, bool); template COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, FloatArray prob, bool replace, bool etype_sorted) { CHECK(prob.defined()); auto pick_fn = GetSamplingRangePickFn(num_samples, prob, replace); return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); } template COOMatrix COORowWisePerEtypeSampling( COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix COORowWisePerEtypeSampling( COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix COORowWisePerEtypeSampling( COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix COORowWisePerEtypeSampling( COOMatrix, IdArray, IdArray, int64_t, FloatArray, bool, bool); template COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) { auto pick_fn = GetSamplingUniformPickFn(num_samples, replace); return COORowWisePick(mat, rows, num_samples, replace, pick_fn); } template COOMatrix COORowWiseSamplingUniform( COOMatrix, IdArray, int64_t, bool); template COOMatrix COORowWiseSamplingUniform( COOMatrix, IdArray, int64_t, bool); template COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes, int64_t num_samples, bool replace, bool etype_sorted) { auto pick_fn = GetSamplingUniformRangePickFn(num_samples, replace); return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); } template COOMatrix COORowWisePerEtypeSamplingUniform( COOMatrix, IdArray, IdArray, int64_t, bool, bool); template COOMatrix COORowWisePerEtypeSamplingUniform( COOMatrix, IdArray, IdArray, int64_t, bool, bool); } // namespace impl } // namespace aten } // namespace dgl