node2vec.cc 1.83 KB
Newer Older
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
/*!
 *  Copyright (c) 2021 by Contributors
 * \file graph/sampling/node2vec.cc
 * \brief Dispatcher of DGL node2vec random walks
 */

#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>

#include "../../../c_api_common.h"
#include "node2vec_impl.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

namespace sampling {

namespace {

void CheckNode2vecInputs(const HeteroGraphPtr hg, const IdArray seeds,
                         const double p, const double q,
                         const int64_t walk_length, const FloatArray &prob) {
  CHECK_INT(seeds, "seeds");
  CHECK_NDIM(seeds, 1, "seeds");
  CHECK_FLOAT(prob, "probability");
  CHECK_NDIM(prob, 1, "probability");
}

/*!
 * \brief Validate inputs, then run node2vec random walks with the
 *        implementation matching the graph's device and the seed ID type.
 *
 * \param hg Graph to walk on; its context selects the device implementation.
 * \param seeds 1-D integer array of start nodes; its dtype selects IdxType.
 * \param p node2vec return parameter, forwarded unchanged.
 * \param q node2vec in-out parameter, forwarded unchanged.
 * \param walk_length Walk length, forwarded unchanged.
 * \param prob 1-D floating-point edge weights, forwarded unchanged.
 * \return The pair produced by impl::Node2vec (presumably node traces and
 *         traversed edge IDs — confirm against node2vec_impl.h).
 */
std::pair<IdArray, IdArray> Node2vec(
    const HeteroGraphPtr hg, const IdArray seeds, const double p,
    const double q, const int64_t walk_length,
    const FloatArray &prob) {
  // Reject malformed arguments before touching any device kernel.
  CheckNode2vecInputs(hg, seeds, p, q, walk_length, prob);

  const auto device_type = hg->Context().device_type;
  std::pair<IdArray, IdArray> walks;
  // Two-level dispatch: device first, then the integer width of the seeds.
  ATEN_XPU_SWITCH(device_type, XPU, "Node2vec", {
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      walks = impl::Node2vec<XPU, IdxType>(
          hg, seeds, p, q, walk_length, prob);
    });
  });
  return walks;
}

/*!
 * \brief Packed-function entry point "sampling.randomwalks._CAPI_DGLSamplingNode2vec".
 *
 * Positional arguments: (graph, seeds, p, q, walk_length, prob).
 * Returns a two-element List wrapping, in order, the first and second arrays
 * produced by sampling::Node2vec (presumably traces then edge IDs — confirm
 * against the Python caller).
 */
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingNode2vec")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Unpack the positional arguments in their declared order.
      HeteroGraphRef graph = args[0];
      IdArray start_nodes = args[1];
      double return_param = args[2];
      double inout_param = args[3];
      int64_t num_steps = args[4];
      FloatArray edge_prob = args[5];

      auto walks = sampling::Node2vec(
          graph.sptr(), start_nodes, return_param, inout_param, num_steps,
          edge_prob);

      // Hand both arrays back to the frontend as a two-element List.
      List<Value> packed;
      packed.push_back(Value(MakeValue(walks.first)));
      packed.push_back(Value(MakeValue(walks.second)));
      *rv = packed;
    });

}  // namespace

}  // namespace sampling

}  // namespace dgl