/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalks.cc
 * \brief Dispatcher of different DGL random walks by device type
 */

#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include <dgl/sampling/randomwalks.h>
#include <utility>
#include <tuple>
#include <vector>
14
#include "../../../c_api_common.h"
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include "randomwalks_impl.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

namespace sampling {

namespace {

void CheckRandomWalkInputs(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  CHECK_INT(seeds, "seeds");
  CHECK_INT(metapath, "metapath");
  CHECK_NDIM(seeds, 1, "seeds");
  CHECK_NDIM(metapath, 1, "metapath");
  for (uint64_t i = 0; i < prob.size(); ++i) {
    FloatArray p = prob[i];
    CHECK_FLOAT(p, "probability");
38
    if (p.GetSize() != 0)
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
      CHECK_NDIM(p, 1, "probability");
  }
}

};  // namespace

std::pair<IdArray, TypeArray> RandomWalk(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);

  TypeArray vtypes;
  IdArray vids;
54
  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalk", {
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
      vids = impl::RandomWalk<XPU, IdxType>(hg, seeds, metapath, prob);
    });
  });

  return std::make_pair(vids, vtypes);
}

std::pair<IdArray, TypeArray> RandomWalkWithRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
  CHECK(restart_prob >= 0 && restart_prob < 1) << "restart probability must belong to [0, 1)";

  TypeArray vtypes;
  IdArray vids;
75
  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithRestart", {
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
      vids = impl::RandomWalkWithRestart<XPU, IdxType>(hg, seeds, metapath, prob, restart_prob);
    });
  });

  return std::make_pair(vids, vtypes);
}

std::pair<IdArray, TypeArray> RandomWalkWithStepwiseRestart(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
  // TODO(BarclayII): check the elements of restart probability

  TypeArray vtypes;
  IdArray vids;
96
  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithStepwiseRestart", {
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
      vids = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(
          hg, seeds, metapath, prob, restart_prob);
    });
  });

  return std::make_pair(vids, vtypes);
}

};  // namespace sampling

// C API: sampling::RandomWalk.
// Args: (graph, seeds, metapath, list of per-edge-type probability arrays).
// Returns: a two-element list [vids, vtypes].
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef hg = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob = args[3];

    const std::vector<FloatArray> prob_vec = ListValueToVector<FloatArray>(prob);

    auto result = sampling::RandomWalk(hg.sptr(), seeds, metapath, prob_vec);
    List<Value> ret;
    ret.push_back(Value(MakeValue(result.first)));
    ret.push_back(Value(MakeValue(result.second)));
    *rv = ret;
  });

// C API: sampling::RandomWalkWithRestart.
// Args: (graph, seeds, metapath, list of per-edge-type probability arrays,
//        scalar restart probability).
// Returns: a two-element list [vids, vtypes].
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef hg = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob = args[3];
    double restart_prob = args[4];

    const std::vector<FloatArray> prob_vec = ListValueToVector<FloatArray>(prob);

    auto result = sampling::RandomWalkWithRestart(
        hg.sptr(), seeds, metapath, prob_vec, restart_prob);
    List<Value> ret;
    ret.push_back(Value(MakeValue(result.first)));
    ret.push_back(Value(MakeValue(result.second)));
    *rv = ret;
  });

// C API: sampling::RandomWalkWithStepwiseRestart.
// Args: (graph, seeds, metapath, list of per-edge-type probability arrays,
//        restart probability array).
// Returns: a two-element list [vids, vtypes].
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwiseRestart")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef hg = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob = args[3];
    FloatArray restart_prob = args[4];

    const std::vector<FloatArray> prob_vec = ListValueToVector<FloatArray>(prob);

    auto result = sampling::RandomWalkWithStepwiseRestart(
        hg.sptr(), seeds, metapath, prob_vec, restart_prob);
    List<Value> ret;
    ret.push_back(Value(MakeValue(result.first)));
    ret.push_back(Value(MakeValue(result.second)));
    *rv = ret;
  });

// C API: pack walk traces into flat arrays.
// Args: (vids, vtypes) as returned by one of the random-walk C APIs.
// Returns: a four-element list [concat_vids, concat_vtypes, lengths, offsets].
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingPackTraces")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    const IdArray traces = args[0];
    const TypeArray trace_types = args[1];

    // Pack the traces with -1 as the pad value (presumably the marker for
    // unused trailing slots -- NOTE(review): confirm against aten::Pack).
    IdArray packed_vids, trace_lengths, trace_offsets;
    std::tie(packed_vids, trace_lengths, trace_offsets) = Pack(traces, -1);

    // Slice the type rows by the same lengths so types stay aligned with IDs;
    // only the concatenated array is kept.
    IdArray packed_vtypes = std::get<0>(ConcatSlices(trace_types, trace_lengths));

    List<Value> ret;
    for (const IdArray &arr : {packed_vids, packed_vtypes, trace_lengths, trace_offsets}) {
      ret.push_back(Value(MakeValue(arr)));
    }
    *rv = ret;
  });

};  // namespace dgl