global_uniform.cc 2.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/*!
 *  Copyright (c) 2021 by Contributors
 * \file graph/sampling/negative/global_uniform.cc
 * \brief Global uniform negative sampling.
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
11
12
#include <dgl/sampling/negative.h>

13
#include <utility>
14

15
16
17
18
19
20
21
22
23
#include "../../../c_api_common.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {
namespace sampling {

std::pair<IdArray, IdArray> GlobalUniformNegativeSampling(
24
25
    HeteroGraphPtr hg, dgl_type_t etype, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
26
27
28
29
30
31
32
33
34
35
36
37
38
39
  auto format = hg->SelectFormat(etype, CSC_CODE | CSR_CODE);
  if (format == SparseFormat::kCSC) {
    CSRMatrix csc = hg->GetCSCMatrix(etype);
    CSRSort_(&csc);
    std::pair<IdArray, IdArray> result = CSRGlobalUniformNegativeSampling(
        csc, num_samples, num_trials, exclude_self_loops, replace, redundancy);
    // reverse the pair since it is CSC
    return {result.second, result.first};
  } else if (format == SparseFormat::kCSR) {
    CSRMatrix csr = hg->GetCSRMatrix(etype);
    CSRSort_(&csr);
    return CSRGlobalUniformNegativeSampling(
        csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);
  } else {
40
41
    LOG(FATAL)
        << "COO format is not supported in global uniform negative sampling";
42
43
44
45
46
    return {IdArray(), IdArray()};
  }
}

DGL_REGISTER_GLOBAL("sampling.negative._CAPI_DGLGlobalUniformNegativeSampling")
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      dgl_type_t etype = args[1];
      CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
      int64_t num_samples = args[2];
      int num_trials = args[3];
      bool exclude_self_loops = args[4];
      bool replace = args[5];
      double redundancy = args[6];
      List<Value> result;
      std::pair<IdArray, IdArray> ret = GlobalUniformNegativeSampling(
          hg.sptr(), etype, num_samples, num_trials, exclude_self_loops,
          replace, redundancy);
      result.push_back(Value(MakeValue(ret.first)));
      result.push_back(Value(MakeValue(ret.second)));
      *rv = result;
    });
64
65
66

};  // namespace sampling
};  // namespace dgl