/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cpu/rowwise_sampling.cc
 * @brief rowwise sampling
 *
 * CPU implementations of row-wise neighbor sampling for CSR and COO
 * matrices. The variants are:
 *   - weighted by a probability or mask array;
 *   - uniform;
 *   - per edge type;
 *   - tag-biased.
 * Each entry point builds sampling callbacks (defined in the anonymous
 * namespace below). It then hands them to a generic pick driver:
 * CSRRowWisePick, CSRRowWisePickFused, CSRRowWisePerEtypePick,
 * COORowWisePick or COORowWisePerEtypePick. NOTE(review): these drivers are
 * not defined here; presumably they are declared in ./rowwise_pick.h.
 *
 * NOTE(review): in this copy of the file, everything that was between angle
 * brackets appears to be missing. Examples:
 *   - the first two #include targets;
 *   - template parameter lists such as `template inline`;
 *   - `static_cast(...)` target types;
 *   - `Ptr()` / `std::vector` element types;
 *   - all explicit-instantiation argument lists.
 * Restore them from version control before building. The comments below
 * describe only what is still visible.
 */
#include
#include
#include "./rowwise_pick.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {

// Equivalent to numpy expression: array[idx[off:off + len]]
// (or array[off:off + len] when idx_data is null).
// The result is a newly allocated array of length `len`, with the same dtype
// and context as `array`. Elements are read and written through FloatType
// pointers, so FloatType must match array->dtype.
// NOTE(review): the dtype is not checked here.
template inline FloatArray DoubleSlice(
    FloatArray array, const IdxType* idx_data, IdxType off, IdxType len) {
  const FloatType* array_data = static_cast(array->data);
  FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
  FloatType* ret_data = static_cast(ret->data);
  for (int64_t j = 0; j < len; ++j) {
    if (idx_data)
      ret_data[j] = array_data[idx_data[off + j]];
    else
      ret_data[j] = array_data[off + j];
  }
  return ret;
}

// Builds the NumPicksFn for sampling weighted by probability or mask.
// For a row whose entries occupy positions [off, off + len), the callback
// counts the entries whose weight is strictly positive. Each entry's weight is
// looked up by edge id `data[i]` when `data` is non-null, and by position `i`
// otherwise.
//   - replace == true:  returns max_num_picks, or 0 if no entry is positive.
//   - replace == false: returns min(max_num_picks, #positive entries).
// max_num_picks is `len` when num_samples == -1, and num_samples otherwise.
// `rowid` and `col` are not read.
template inline NumPicksFn GetSamplingNumPicksFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  NumPicksFn num_picks_fn = [prob_or_mask, num_samples, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                const IdxType* col, const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const DType* prob_or_mask_data = prob_or_mask.Ptr();
    // Count the entries that can actually be drawn (weight > 0).
    IdxType nnz = 0;
    for (IdxType i = off; i < off + len; ++i) {
      const IdxType eid = data ? data[i] : i;
      if (prob_or_mask_data[eid] > 0) {
        ++nnz;
      }
    }

    if (replace) {
      // With replacement, any single positive entry allows the full fanout.
      return static_cast(nnz == 0 ? 0 : max_num_picks);
    } else {
      // Without replacement, cannot pick more than the positive entries.
      return std::min(static_cast(max_num_picks), nnz);
    }
  };
  return num_picks_fn;
}

// Builds the PickFn for sampling weighted by probability or mask. It:
//   1. gathers the row's weights into a contiguous array (DoubleSlice);
//   2. draws `num_picks` positions from it with RandomEngine::Choice;
//   3. shifts each drawn position by `off`.
// NOTE(review): this assumes Choice returns indices in [0, len), so that
// out_idx ends up in [off, off + len). Confirm against RandomEngine.
// `num_samples` is captured but not used by the callback.
template inline PickFn GetSamplingPickFn(
    int64_t num_samples, NDArray prob_or_mask, bool replace) {
  PickFn pick_fn = [prob_or_mask, num_samples, replace](
                       IdxType rowid, IdxType off, IdxType len,
                       IdxType num_picks, const IdxType* col,
                       const IdxType* data, IdxType* out_idx) {
    NDArray prob_or_mask_selected =
        DoubleSlice(prob_or_mask, data, off, len);
    RandomEngine::ThreadLocal()->Choice(
        num_picks, prob_or_mask_selected, out_idx, replace);
    // Convert row-local picks back to positions in the full index range.
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

// Builds the per-edge-type range pick callback for weighted sampling.
// For edge type `cur_et`, the callback builds a weight vector of length
// `et_len`:
//   - entry j is prob[cur_et][et_eid[et_idx[et_offset + j]]];
//   - every entry is 1 when prob[cur_et] is a null array (uniform weights).
// It then draws num_samples[cur_et] indices into out_idx.
// Unlike the row-wise PickFns, the drawn indices are NOT shifted by `off`.
// NOTE(review): presumably they are relative to the edge-type range, and
// CSRRowWisePerEtypePick / COORowWisePerEtypePick map them back. Confirm in
// rowwise_pick.h.
// The `off` and `eid` arguments are not read.
template inline EtypeRangePickFn GetSamplingRangePickFn(
    const std::vector& num_samples, const std::vector& prob,
    bool replace) {
  EtypeRangePickFn pick_fn = [prob, num_samples, replace](
                                 IdxType off, IdxType et_offset, IdxType cur_et,
                                 IdxType et_len,
                                 const std::vector& et_idx,
                                 const std::vector& et_eid,
                                 const IdxType* eid, IdxType* out_idx) {
    const FloatArray& p = prob[cur_et];
    // A null probability array for this edge type means uniform weights.
    const FloatType* p_data = IsNullArray(p) ? nullptr : p.Ptr();
    FloatArray probs = FloatArray::Empty({et_len}, p->dtype, p->ctx);
    FloatType* probs_data = probs.Ptr();
    for (int64_t j = 0; j < et_len; ++j) {
      const IdxType cur_eid = et_eid[et_idx[et_offset + j]];
      probs_data[j] = p_data ? p_data[cur_eid] : static_cast(1.);
    }

    RandomEngine::ThreadLocal()->Choice(
        num_samples[cur_et], probs, out_idx, replace);
  };
  return pick_fn;
}

// Builds the NumPicksFn for uniform sampling, where every entry of the row
// is pickable.
//   - replace == true:  returns max_num_picks, or 0 for an empty row.
//   - replace == false: returns min(max_num_picks, len).
// max_num_picks is `len` when num_samples == -1, and num_samples otherwise.
template inline NumPicksFn GetSamplingUniformNumPicksFn(
    int64_t num_samples, bool replace) {
  NumPicksFn num_picks_fn = [num_samples, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                const IdxType* col, const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    if (replace) {
      return static_cast(len == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast(max_num_picks), len);
    }
  };
  return num_picks_fn;
}

// Builds the PickFn for uniform sampling. It draws `num_picks` of `len`
// positions via RandomEngine::UniformChoice, then shifts each one by `off`.
template inline PickFn GetSamplingUniformPickFn(
    int64_t num_samples, bool replace) {
  PickFn pick_fn = [num_samples, replace](
                       IdxType rowid, IdxType off, IdxType len,
                       IdxType num_picks, const IdxType* col,
                       const IdxType* data, IdxType* out_idx) {
    RandomEngine::ThreadLocal()->UniformChoice(
        num_picks, len, out_idx, replace);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

// Builds the per-edge-type range pick callback for uniform sampling. It draws
// num_samples[cur_et] of `et_len` indices via UniformChoice.
// As in GetSamplingRangePickFn, the results are not shifted by `off`. The
// callback reads only cur_et, et_len and out_idx.
template inline EtypeRangePickFn GetSamplingUniformRangePickFn(
    const std::vector& num_samples, bool replace) {
  EtypeRangePickFn pick_fn = [num_samples, replace](
                                 IdxType off, IdxType et_offset, IdxType cur_et,
                                 IdxType et_len,
                                 const std::vector& et_idx,
                                 const std::vector& et_eid,
                                 const IdxType* data, IdxType* out_idx) {
    RandomEngine::ThreadLocal()->UniformChoice(
        num_samples[cur_et], et_len, out_idx, replace);
  };
  return pick_fn;
}

// Builds the NumPicksFn for tag-biased sampling.
// `split` is indexed as a row-major 2-D array with split->shape[1] entries
// per row. NOTE(review): this assumes shape (num_rows, num_tags + 1); confirm.
// Row `rowid`'s entries tag_offset[0..num_tags] delimit its per-tag segments.
// Only the differences between consecutive entries are used here. The
// callback counts the entries in the segments whose bias[tag] > 0.
//   - replace == true:  returns max_num_picks, or 0 if that count is 0.
//   - replace == false: returns min(max_num_picks, count).
template inline NumPicksFn GetSamplingBiasedNumPicksFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  NumPicksFn num_picks_fn = [num_samples, split, bias, replace](
                                IdxType rowid, IdxType off, IdxType len,
                                const IdxType* col, const IdxType* data) {
    const int64_t max_num_picks = (num_samples == -1) ? len : num_samples;
    const int64_t num_tags = split->shape[1] - 1;
    const IdxType* tag_offset = split.Ptr() + rowid * split->shape[1];
    const FloatType* bias_data = bias.Ptr();
    // Entries whose tag has zero (or negative) bias can never be drawn.
    IdxType nnz = 0;
    for (int64_t j = 0; j < num_tags; ++j) {
      if (bias_data[j] > 0) {
        nnz += tag_offset[j + 1] - tag_offset[j];
      }
    }

    if (replace) {
      return static_cast(nnz == 0 ? 0 : max_num_picks);
    } else {
      return std::min(static_cast(max_num_picks), nnz);
    }
  };
  return num_picks_fn;
}

// Builds the PickFn for tag-biased sampling. It passes the row's tag offsets
// (the same indexing of `split` as above) and the per-tag `bias` to
// RandomEngine::BiasedChoice, then shifts the drawn positions by `off`.
template inline PickFn GetSamplingBiasedPickFn(
    int64_t num_samples, IdArray split, FloatArray bias, bool replace) {
  PickFn pick_fn = [num_samples, split, bias, replace](
                       IdxType rowid, IdxType off, IdxType len,
                       IdxType num_picks, const IdxType* col,
                       const IdxType* data, IdxType* out_idx) {
    const IdxType* tag_offset = split.Ptr() + rowid * split->shape[1];
    RandomEngine::ThreadLocal()->BiasedChoice(
        num_picks, tag_offset, bias, out_idx, replace);
    for (int64_t j = 0; j < num_picks; ++j) {
      out_idx[j] += off;
    }
  };
  return pick_fn;
}

}  // namespace

/////////////////////////////// CSR ///////////////////////////////

// Row-wise sampling of `rows` from a CSR matrix, weighted by `prob_or_mask`.
// Only entries with a strictly positive weight count toward the number of
// picks (see GetSamplingNumPicksFn). `prob_or_mask` must be defined.
// Delegates to CSRRowWisePick.
template COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn =
      GetSamplingNumPicksFn(num_samples, prob_or_mask, replace);
  auto pick_fn = GetSamplingPickFn(num_samples, prob_or_mask, replace);
  return CSRRowWisePick(
      mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

// Explicit instantiations (8). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseSampling(
    CSRMatrix, IdArray, int64_t, NDArray, bool);

// Fused variant of CSRRowWiseSampling. It uses the same callbacks but
// delegates to CSRRowWisePickFused, additionally forwarding `seed_mapping`
// and `new_seed_nodes`.
// NOTE(review): how the `map_seed_nodes` template flag reaches
// CSRRowWisePickFused is not visible in this copy (template arguments
// missing).
template <
    DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
std::pair CSRRowWiseSamplingFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector* new_seed_nodes, int64_t num_samples,
    NDArray prob_or_mask, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn =
      GetSamplingNumPicksFn(num_samples, prob_or_mask, replace);
  auto pick_fn = GetSamplingPickFn(num_samples, prob_or_mask, replace);
  return CSRRowWisePickFused(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,
      num_picks_fn);
}

// Explicit instantiations (16). NOTE(review): the template arguments are not
// visible in this copy.
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);
template std::pair CSRRowWiseSamplingFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool);

// Per-edge-type weighted sampling on a CSR matrix. It requires:
//   - one probability/mask array per edge type (size checked against
//     num_samples);
//   - every such array to be defined (an empty/null array means uniform
//     weights for that type; see GetSamplingRangePickFn).
// It delegates to CSRRowWisePerEtypePick with the range pick callback and the
// probability arrays.
// NOTE(review): unlike CSRRowWiseSampling, `replace` is not reset here for
// num_samples == -1. Presumably the driver handles that case; confirm.
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset,
    const std::vector& num_samples,
    const std::vector& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors does not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn(
      num_samples, prob_or_mask, replace);
  return CSRRowWisePerEtypePick(
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, prob_or_mask);
}

// Explicit instantiations (8). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool, bool);

// Uniform row-wise sampling of `rows` from a CSR matrix. Delegates to
// CSRRowWisePick with the uniform callbacks.
template COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn(num_samples, replace);
  auto pick_fn = GetSamplingUniformPickFn(num_samples, replace);
  return CSRRowWisePick(
      mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

// Explicit instantiations (2). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform(
    CSRMatrix, IdArray, int64_t, bool);

// Fused variant of CSRRowWiseSamplingUniform. It delegates to
// CSRRowWisePickFused, additionally forwarding `seed_mapping` and
// `new_seed_nodes`.
template std::pair CSRRowWiseSamplingUniformFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector* new_seed_nodes, int64_t num_samples, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn(num_samples, replace);
  auto pick_fn = GetSamplingUniformPickFn(num_samples, replace);
  return CSRRowWisePickFused(
      mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,
      num_picks_fn);
}

// Explicit instantiations (4). NOTE(review): the template arguments are not
// visible in this copy.
template std::pair CSRRowWiseSamplingUniformFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool);
template std::pair CSRRowWiseSamplingUniformFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool);
template std::pair CSRRowWiseSamplingUniformFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool);
template std::pair CSRRowWiseSamplingUniformFused(
    CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool);

// Per-edge-type uniform sampling on a CSR matrix. It passes an empty
// probability list (`{}`) to CSRRowWisePerEtypePick.
template COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset,
    const std::vector& num_samples, bool replace,
    bool rowwise_etype_sorted) {
  auto pick_fn =
      GetSamplingUniformRangePickFn(num_samples, replace);
  return CSRRowWisePerEtypePick(
      mat, rows, eid2etype_offset, num_samples, replace, rowwise_etype_sorted,
      pick_fn, {});
}

// Explicit instantiations (2). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix, IdArray, const std::vector&, const std::vector&, bool,
    bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform(
    CSRMatrix, IdArray, const std::vector&, const std::vector&, bool,
    bool);

// Tag-biased row-wise sampling on a CSR matrix. `tag_offset` is forwarded as
// the `split` argument of the biased callbacks. `bias` holds one weight per
// tag (see GetSamplingBiasedNumPicksFn).
template COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn = GetSamplingBiasedNumPicksFn(
      num_samples, tag_offset, bias, replace);
  auto pick_fn = GetSamplingBiasedPickFn(
      num_samples, tag_offset, bias, replace);
  return CSRRowWisePick(
      mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

// Explicit instantiations (4). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);

/////////////////////////////// COO ///////////////////////////////

// COO counterpart of CSRRowWiseSampling. It uses the same weighted callbacks
// and delegates to COORowWisePick.
template COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  CHECK(prob_or_mask.defined());
  auto num_picks_fn =
      GetSamplingNumPicksFn(num_samples, prob_or_mask, replace);
  auto pick_fn = GetSamplingPickFn(num_samples, prob_or_mask, replace);
  return COORowWisePick(
      mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

// Explicit instantiations (8). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseSampling(
    COOMatrix, IdArray, int64_t, NDArray, bool);

// COO counterpart of CSRRowWisePerEtypeSampling. It takes no
// rowwise_etype_sorted flag and delegates to COORowWisePerEtypePick.
// NOTE(review): this CHECK message says "do not match" while the CSR variant
// says "does not match". The difference is harmless, but the two messages
// could be unified.
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector& eid2etype_offset,
    const std::vector& num_samples,
    const std::vector& prob_or_mask, bool replace) {
  CHECK(prob_or_mask.size() == num_samples.size())
      << "the number of probability tensors do not match the number of edge "
         "types.";
  for (auto& p : prob_or_mask) CHECK(p.defined());
  auto pick_fn = GetSamplingRangePickFn(
      num_samples, prob_or_mask, replace);
  return COORowWisePerEtypePick(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn,
      prob_or_mask);
}

// Explicit instantiations (8). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSampling(
    COOMatrix, IdArray, const std::vector&, const std::vector&,
    const std::vector&, bool);

// COO counterpart of CSRRowWiseSamplingUniform.
template COOMatrix COORowWiseSamplingUniform(
    COOMatrix mat, IdArray rows, int64_t num_samples, bool replace) {
  // If num_samples is -1, select all neighbors without replacement.
  replace = (replace && num_samples != -1);
  auto num_picks_fn =
      GetSamplingUniformNumPicksFn(num_samples, replace);
  auto pick_fn = GetSamplingUniformPickFn(num_samples, replace);
  return COORowWisePick(
      mat, rows, num_samples, replace, pick_fn, num_picks_fn);
}

// Explicit instantiations (2). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix COORowWiseSamplingUniform(
    COOMatrix, IdArray, int64_t, bool);
template COOMatrix COORowWiseSamplingUniform(
    COOMatrix, IdArray, int64_t, bool);

// COO counterpart of CSRRowWisePerEtypeSamplingUniform. It passes an empty
// probability list (`{}`) to COORowWisePerEtypePick.
template COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix mat, IdArray rows, const std::vector& eid2etype_offset,
    const std::vector& num_samples, bool replace) {
  auto pick_fn =
      GetSamplingUniformRangePickFn(num_samples, replace);
  return COORowWisePerEtypePick(
      mat, rows, eid2etype_offset, num_samples, replace, pick_fn, {});
}

// Explicit instantiations (2). NOTE(review): the template arguments are not
// visible in this copy.
template COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix, IdArray, const std::vector&, const std::vector&, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform(
    COOMatrix, IdArray, const std::vector&, const std::vector&, bool);

}  // namespace impl
}  // namespace aten
}  // namespace dgl