/**
 *  Copyright (c) 2021 by Contributors
 * @file array/cpu/negative_sampling.cc
 * @brief Uniform negative sampling on CSR.
 */

#include <dgl/array.h>
#include <dgl/array_iterator.h>
#include <dgl/random.h>
10
11
#include <dgl/runtime/parallel_for.h>

12
#include <algorithm>
13
#include <utility>
14
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {
namespace aten {
namespace impl {

21
template <DGLDeviceType XPU, typename IdType>
22
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
23
24
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
25
26
  const int64_t num_row = csr.num_rows;
  const int64_t num_col = csr.num_cols;
27
28
  const int64_t num_actual_samples =
      static_cast<int64_t>(num_samples * (1 + redundancy));
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  IdArray row = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdArray col = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdType* row_data = row.Ptr<IdType>();
  IdType* col_data = col.Ptr<IdType>();

  parallel_for(0, num_actual_samples, 1, [&](int64_t b, int64_t e) {
    for (int64_t i = b; i < e; ++i) {
      for (int trial = 0; trial < num_trials; ++trial) {
        IdType u = RandomEngine::ThreadLocal()->RandInt(num_row);
        IdType v = RandomEngine::ThreadLocal()->RandInt(num_col);
        if (!(exclude_self_loops && (u == v)) && !CSRIsNonZero(csr, u, v)) {
          row_data[i] = u;
          col_data[i] = v;
          break;
        }
      }
    }
  });

  PairIterator<IdType> begin(row_data, col_data);
49
50
  PairIterator<IdType> end = std::remove_if(
      begin, begin + num_actual_samples,
51
52
      [](const std::pair<IdType, IdType>& val) { return val.first == -1; });
  if (!replace) {
53
54
55
56
57
58
59
    std::sort(
        begin, end,
        [](const std::pair<IdType, IdType>& a,
           const std::pair<IdType, IdType>& b) {
          return a.first < b.first ||
                 (a.first == b.first && a.second < b.second);
        });
60
61
    end = std::unique(begin, end);
  }
62
63
64
65
66
  int64_t num_sampled =
      std::min(static_cast<int64_t>(end - begin), num_samples);
  return {
      row.CreateView({num_sampled}, row->dtype),
      col.CreateView({num_sampled}, col->dtype)};
67
68
}

69
70
71
72
// Explicit instantiations of the CPU kernel for 32-bit and 64-bit ID types,
// so that the template definition above is emitted in this translation unit.
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCPU, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCPU, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double);
73
74
75
76

};  // namespace impl
};  // namespace aten
};  // namespace dgl