metapath.cc 2.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/sampler/metapath.cc
 * \brief Metapath sampling
 */

#include <dgl/array.h>
#include <dgl/random.h>
#include <dgl/packed_func_ext.h>
#include "../../c_api_common.h"
#include "randomwalk.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

namespace sampling {

namespace {

/*!
 * \brief Random walk based on the given metapath.
 *
 * \param hg The heterograph
 * \param etypes The metapath as an array of edge type IDs
 * \param seeds The array of starting vertices for random walks
 * \param num_traces Number of traces to generate for each starting vertex
 * \note The metapath should have the same starting and ending node type.
 */
RandomWalkTracesPtr MetapathRandomWalk(
    const HeteroGraphPtr hg,
    const IdArray etypes,
    const IdArray seeds,
    int num_traces) {
  const auto metagraph = hg->meta_graph();
  uint64_t num_etypes = etypes->shape[0];
  uint64_t num_seeds = seeds->shape[0];
  const dgl_type_t *etype_data = static_cast<dgl_type_t *>(etypes->data);
  const dgl_id_t *seed_data = static_cast<dgl_id_t *>(seeds->data);

  std::vector<dgl_id_t> vertices;
  std::vector<size_t> trace_lengths, trace_counts;

  // TODO(quan): use omp to parallelize this loop
  for (uint64_t seed_id = 0; seed_id < num_seeds; ++seed_id) {
    int curr_num_traces = 0;

    for (; curr_num_traces < num_traces; ++curr_num_traces) {
      dgl_id_t curr = seed_data[seed_id];

      size_t trace_length = 0;

      for (size_t i = 0; i < num_etypes; ++i) {
        const auto &succ = hg->SuccVec(etype_data[i], curr);
        if (succ.size() == 0)
          break;
        curr = succ[RandomEngine::ThreadLocal()->RandInt(succ.size())];
        vertices.push_back(curr);
        ++trace_length;
      }

      trace_lengths.push_back(trace_length);
    }

    trace_counts.push_back(curr_num_traces);
  }

  RandomWalkTraces *tl = new RandomWalkTraces;
  tl->vertices = VecToIdArray(vertices);
  tl->trace_lengths = VecToIdArray(trace_lengths);
  tl->trace_counts = VecToIdArray(trace_counts);

  return RandomWalkTracesPtr(tl);
}

};  // namespace

DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    const HeteroGraphRef hg = args[0];
    const IdArray etypes = args[1];
    const IdArray seeds = args[2];
    int num_traces = args[3];

    const auto tl = MetapathRandomWalk(hg.sptr(), etypes, seeds, num_traces);
    *rv = RandomWalkTracesRef(tl);
  });

};  // namespace sampling

};  // namespace dgl