/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalks.cc
 * \brief Dispatcher of different DGL random walks by device type
 */

#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include <dgl/sampling/randomwalks.h>
#include <utility>
#include <tuple>
#include <vector>

#include "../../../c_api_common.h"
#include "randomwalks_impl.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

namespace sampling {

namespace {

void CheckRandomWalkInputs(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  CHECK_INT(seeds, "seeds");
  CHECK_INT(metapath, "metapath");
  CHECK_NDIM(seeds, 1, "seeds");
  CHECK_NDIM(metapath, 1, "metapath");
  for (uint64_t i = 0; i < prob.size(); ++i) {
    FloatArray p = prob[i];
    CHECK_FLOAT(p, "probability");
38
    if (p.GetSize() != 0)
39
40
41
42
43
44
      CHECK_NDIM(p, 1, "probability");
  }
}

};  // namespace

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
45
std::tuple<IdArray, IdArray, TypeArray> RandomWalk(
46
47
48
49
50
51
52
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
53
  std::pair<IdArray, IdArray> result;
54
  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalk", {
55
56
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
57
      result = impl::RandomWalk<XPU, IdxType>(hg, seeds, metapath, prob);
58
59
60
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
61
  return std::make_tuple(result.first, result.second, vtypes);
62
63
}

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
64
std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(
65
66
67
68
69
70
71
72
73
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
  CHECK(restart_prob >= 0 && restart_prob < 1) << "restart probability must belong to [0, 1)";

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
74
  std::pair<IdArray, IdArray> result;
75
  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithRestart", {
76
77
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
78
      result = impl::RandomWalkWithRestart<XPU, IdxType>(hg, seeds, metapath, prob, restart_prob);
79
80
81
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
82
  return std::make_tuple(result.first, result.second, vtypes);
83
84
}

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
85
std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(
86
87
88
89
90
91
92
93
94
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
  // TODO(BarclayII): check the elements of restart probability

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
95
  std::pair<IdArray, IdArray> result;
96
  ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithStepwiseRestart", {
97
98
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
99
      result = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(
100
101
102
103
          hg, seeds, metapath, prob, restart_prob);
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
104
  return std::make_tuple(result.first, result.second, vtypes);
105
106
}

107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
    const IdArray src,
    const IdArray dst,
    const int64_t num_samples_per_node,
    const int64_t k) {
  assert((src->ndim == 1) && (dst->ndim == 1)
          && (src->shape[0] % num_samples_per_node == 0)
          && (src->shape[0] == dst->shape[0]));
  std::tuple<IdArray, IdArray, IdArray> result;

  ATEN_XPU_SWITCH((src->ctx).device_type, XPU, "SelectPinSageNeighbors", {
    ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, {
      result = impl::SelectPinSageNeighbors<XPU, IdxType>(src, dst, num_samples_per_node, k);
    });
  });

  return result;
}

126
127
128
129
130
131
132
133
134
};  // namespace sampling

// FFI entry point for sampling::RandomWalk.
// Arguments: (graph, seeds, metapath, list of probability arrays).
// Returns a list holding the three elements of the result tuple, in order.
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef graph = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob_list = args[3];
    const auto& probs = ListValueToVector<FloatArray>(prob_list);

    auto walk = sampling::RandomWalk(graph.sptr(), seeds, metapath, probs);

    List<Value> out;
    out.push_back(Value(MakeValue(std::get<0>(walk))));
    out.push_back(Value(MakeValue(std::get<1>(walk))));
    out.push_back(Value(MakeValue(std::get<2>(walk))));
    *rv = out;
  });

// FFI entry point for sampling::SelectPinSageNeighbors.
// Arguments: (src, dst, number of samples per node, k).
// Returns a list holding the three elements of the result tuple, in order.
DGL_REGISTER_GLOBAL("sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    IdArray src = args[0];
    IdArray dst = args[1];
    int64_t num_traversals = static_cast<int64_t>(args[2]);
    int64_t k = static_cast<int64_t>(args[3]);

    auto selected = sampling::SelectPinSageNeighbors(src, dst, num_traversals, k);

    List<Value> out;
    out.push_back(Value(MakeValue(std::get<0>(selected))));
    out.push_back(Value(MakeValue(std::get<1>(selected))));
    out.push_back(Value(MakeValue(std::get<2>(selected))));
    *rv = out;
  });

// FFI entry point for sampling::RandomWalkWithRestart.
// Arguments: (graph, seeds, metapath, list of probability arrays,
//             scalar restart probability).
// Returns a list holding the three elements of the result tuple, in order.
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef graph = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob_list = args[3];
    double restart_prob = args[4];
    const auto& probs = ListValueToVector<FloatArray>(prob_list);

    auto walk = sampling::RandomWalkWithRestart(
        graph.sptr(), seeds, metapath, probs, restart_prob);

    List<Value> out;
    out.push_back(Value(MakeValue(std::get<0>(walk))));
    out.push_back(Value(MakeValue(std::get<1>(walk))));
    out.push_back(Value(MakeValue(std::get<2>(walk))));
    *rv = out;
  });

// FFI entry point for sampling::RandomWalkWithStepwiseRestart.
// Arguments: (graph, seeds, metapath, list of probability arrays,
//             restart probability array).
// Returns a list holding the three elements of the result tuple, in order.
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwiseRestart")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef graph = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob_list = args[3];
    FloatArray restart_prob = args[4];
    const auto& probs = ListValueToVector<FloatArray>(prob_list);

    auto walk = sampling::RandomWalkWithStepwiseRestart(
        graph.sptr(), seeds, metapath, probs, restart_prob);

    List<Value> out;
    out.push_back(Value(MakeValue(std::get<0>(walk))));
    out.push_back(Value(MakeValue(std::get<1>(walk))));
    out.push_back(Value(MakeValue(std::get<2>(walk))));
    *rv = out;
  });

// FFI entry point that packs random-walk traces.
// Arguments: (vids, vtypes).
// Returns a list: (packed vids, concatenated vtype slices, lengths, offsets).
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingPackTraces")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    IdArray vids = args[0];
    TypeArray vtypes = args[1];

    // aten::Pack with pad value -1 yields (packed data, lengths, offsets);
    // presumably -1 entries are treated as padding and dropped -- see aten::Pack.
    IdArray packed_vids, packed_vtypes, slice_lengths, slice_offsets;
    std::tie(packed_vids, slice_lengths, slice_offsets) = Pack(vids, -1);
    // Take matching slices of vtypes using the lengths computed above; the
    // second output of ConcatSlices is not needed here.
    std::tie(packed_vtypes, std::ignore) = ConcatSlices(vtypes, slice_lengths);

    List<Value> out;
    out.push_back(Value(MakeValue(packed_vids)));
    out.push_back(Value(MakeValue(packed_vtypes)));
    out.push_back(Value(MakeValue(slice_lengths)));
    out.push_back(Value(MakeValue(slice_offsets)));
    *rv = out;
  });

};  // namespace dgl