/**
 *  Copyright (c) 2021 by Contributors
 * @file graph/sampling/randomwalks/node2vec.cc
 * @brief Dispatcher of DGL node2vec random walks
 */

#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>

#include "../../../c_api_common.h"
#include "node2vec_impl.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

namespace sampling {

namespace {

23
24
25
void CheckNode2vecInputs(
    const HeteroGraphPtr hg, const IdArray seeds, const double p,
    const double q, const int64_t walk_length, const FloatArray &prob) {
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
26
27
28
29
30
31
32
33
  CHECK_INT(seeds, "seeds");
  CHECK_NDIM(seeds, 1, "seeds");
  CHECK_FLOAT(prob, "probability");
  CHECK_NDIM(prob, 1, "probability");
}

std::pair<IdArray, IdArray> Node2vec(
    const HeteroGraphPtr hg, const IdArray seeds, const double p,
34
    const double q, const int64_t walk_length, const FloatArray &prob) {
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
  CheckNode2vecInputs(hg, seeds, p, q, walk_length, prob);

  std::pair<IdArray, IdArray> result;
  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "Node2vec", {
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      result = impl::Node2vec<XPU, IdxType>(hg, seeds, p, q, walk_length, prob);
    });
  });

  return result;
}

// C API entry point for node2vec random walks. The frontend passes its
// arguments positionally as (graph, seeds, p, q, walk_length, prob). The two
// arrays returned by Node2vec are handed back, in order, inside a
// List<Value>.
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingNode2vec")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Unpack the positional arguments.
      HeteroGraphRef graph = args[0];
      const IdArray start_nodes = args[1];
      const double return_p = args[2];
      const double inout_q = args[3];
      const int64_t length = args[4];
      const FloatArray edge_prob = args[5];

      std::pair<IdArray, IdArray> walks = sampling::Node2vec(
          graph.sptr(), start_nodes, return_p, inout_q, length, edge_prob);

      // Box each output array so the frontend receives a two-element list.
      List<Value> packed;
      packed.push_back(Value(MakeValue(walks.first)));
      packed.push_back(Value(MakeValue(walks.second)));
      *rv = packed;
    });

}  // namespace

}  // namespace sampling

}  // namespace dgl