/*!
 *  Copyright (c) 2021 by Contributors
 * \file array/cpu/negative_sampling.cc
 * \brief Uniform negative sampling on CSR.
 */

#include <dgl/array.h>
#include <dgl/array_iterator.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/random.h>
#include <utility>
#include <algorithm>

using namespace dgl::runtime;

namespace dgl {
namespace aten {
namespace impl {

20
template <DGLDeviceType XPU, typename IdType>
21
22
23
24
25
26
27
28
29
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix &csr,
    int64_t num_samples,
    int num_trials,
    bool exclude_self_loops,
    bool replace,
    double redundancy) {
  const int64_t num_row = csr.num_rows;
  const int64_t num_col = csr.num_cols;
30
  const int64_t num_actual_samples = static_cast<int64_t>(num_samples * (1 + redundancy));
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
  IdArray row = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdArray col = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdType* row_data = row.Ptr<IdType>();
  IdType* col_data = col.Ptr<IdType>();

  parallel_for(0, num_actual_samples, 1, [&](int64_t b, int64_t e) {
    for (int64_t i = b; i < e; ++i) {
      for (int trial = 0; trial < num_trials; ++trial) {
        IdType u = RandomEngine::ThreadLocal()->RandInt(num_row);
        IdType v = RandomEngine::ThreadLocal()->RandInt(num_col);
        if (!(exclude_self_loops && (u == v)) && !CSRIsNonZero(csr, u, v)) {
          row_data[i] = u;
          col_data[i] = v;
          break;
        }
      }
    }
  });

  PairIterator<IdType> begin(row_data, col_data);
  PairIterator<IdType> end = std::remove_if(begin, begin + num_actual_samples,
      [](const std::pair<IdType, IdType>& val) { return val.first == -1; });
  if (!replace) {
    std::sort(begin, end,
        [](const std::pair<IdType, IdType>& a, const std::pair<IdType, IdType>& b) {
          return a.first < b.first || (a.first == b.first && a.second < b.second);
        });;
    end = std::unique(begin, end);
  }
60
  int64_t num_sampled = std::min(static_cast<int64_t>(end - begin), num_samples);
61
62
63
  return {row.CreateView({num_sampled}, row->dtype), col.CreateView({num_sampled}, col->dtype)};
}

64
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int32_t>(
65
    const CSRMatrix&, int64_t, int, bool, bool, double);
66
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int64_t>(
67
68
69
70
71
    const CSRMatrix&, int64_t, int, bool, bool, double);

};  // namespace impl
};  // namespace aten
};  // namespace dgl