/*! * Copyright (c) 2021 by Contributors * \file array/cpu/negative_sampling.cc * \brief Uniform negative sampling on CSR. */ #include #include #include #include #include #include using namespace dgl::runtime; namespace dgl { namespace aten { namespace impl { template std::pair CSRGlobalUniformNegativeSampling( const CSRMatrix& csr, int64_t num_samples, int num_trials, bool exclude_self_loops, bool replace, double redundancy) { const int64_t num_row = csr.num_rows; const int64_t num_col = csr.num_cols; const int64_t num_actual_samples = static_cast(num_samples * (1 + redundancy)); IdArray row = Full(-1, num_actual_samples, csr.indptr->ctx); IdArray col = Full(-1, num_actual_samples, csr.indptr->ctx); IdType* row_data = row.Ptr(); IdType* col_data = col.Ptr(); parallel_for(0, num_actual_samples, 1, [&](int64_t b, int64_t e) { for (int64_t i = b; i < e; ++i) { for (int trial = 0; trial < num_trials; ++trial) { IdType u = RandomEngine::ThreadLocal()->RandInt(num_row); IdType v = RandomEngine::ThreadLocal()->RandInt(num_col); if (!(exclude_self_loops && (u == v)) && !CSRIsNonZero(csr, u, v)) { row_data[i] = u; col_data[i] = v; break; } } } }); PairIterator begin(row_data, col_data); PairIterator end = std::remove_if( begin, begin + num_actual_samples, [](const std::pair& val) { return val.first == -1; }); if (!replace) { std::sort( begin, end, [](const std::pair& a, const std::pair& b) { return a.first < b.first || (a.first == b.first && a.second < b.second); }); end = std::unique(begin, end); } int64_t num_sampled = std::min(static_cast(end - begin), num_samples); return { row.CreateView({num_sampled}, row->dtype), col.CreateView({num_sampled}, col->dtype)}; } template std::pair CSRGlobalUniformNegativeSampling< kDGLCPU, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double); template std::pair CSRGlobalUniformNegativeSampling< kDGLCPU, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double); }; // namespace impl }; // namespace aten }; // namespace dgl