/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalks.cc
 * \brief Dispatcher of different DGL random walks by device type
 */

#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include <dgl/sampling/randomwalks.h>
#include <utility>
#include <tuple>
#include <vector>
#include "../../../c_api_common.h"
#include "randomwalks_impl.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

namespace sampling {

namespace {

void CheckRandomWalkInputs(
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  CHECK_INT(seeds, "seeds");
  CHECK_INT(metapath, "metapath");
  CHECK_NDIM(seeds, 1, "seeds");
  CHECK_NDIM(metapath, 1, "metapath");
35
36
37
38
39
40
41
42
43
44
  // (Xin): metapath is copied to GPU in CUDA random walk code
  // CHECK_SAME_CONTEXT(seeds, metapath);

  if (hg->IsPinned()) {
    CHECK_EQ(seeds->ctx.device_type, kDLGPU) << "Expected seeds (" << seeds->ctx << ")" \
      << " to be on the GPU when the graph is pinned.";
  } else if (hg->Context() != seeds->ctx) {
    LOG(FATAL) << "Expected seeds (" << seeds->ctx << ")" << " to have the same " \
      << "context as graph (" << hg->Context() << ").";
  }
45
46
  for (uint64_t i = 0; i < prob.size(); ++i) {
    FloatArray p = prob[i];
47
48
    CHECK_EQ(hg->Context(), p->ctx) << "Expected prob (" << p->ctx << ")" << " to have the same " \
      << "context as graph (" << hg->Context() << ").";
49
    CHECK_FLOAT(p, "probability");
50
51
52
    if (p.GetSize() != 0) {
      CHECK_EQ(hg->IsPinned(), p.IsPinned())
        << "The prob array should have the same pinning status as the graph";
53
      CHECK_NDIM(p, 1, "probability");
54
    }
55
56
57
58
59
  }
}

};  // namespace

std::tuple<IdArray, IdArray, TypeArray> RandomWalk(
61
62
63
64
65
66
67
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
68
  std::pair<IdArray, IdArray> result;
69
  ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalk", {
70
71
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
72
      result = impl::RandomWalk<XPU, IdxType>(hg, seeds, metapath, prob);
73
74
75
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
76
  return std::make_tuple(result.first, result.second, vtypes);
77
78
}

std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(
80
81
82
83
84
85
86
87
88
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    double restart_prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
  CHECK(restart_prob >= 0 && restart_prob < 1) << "restart probability must belong to [0, 1)";

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
89
  std::pair<IdArray, IdArray> result;
90
  ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalkWithRestart", {
91
92
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
93
      result = impl::RandomWalkWithRestart<XPU, IdxType>(hg, seeds, metapath, prob, restart_prob);
94
95
96
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
97
  return std::make_tuple(result.first, result.second, vtypes);
98
99
}

std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(
101
102
103
104
105
106
107
108
109
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob,
    FloatArray restart_prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
  // TODO(BarclayII): check the elements of restart probability

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
110
  std::pair<IdArray, IdArray> result;
111
  ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalkWithStepwiseRestart", {
112
113
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
114
      result = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(
115
116
117
118
          hg, seeds, metapath, prob, restart_prob);
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
119
  return std::make_tuple(result.first, result.second, vtypes);
120
121
}

std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
    const IdArray src,
    const IdArray dst,
    const int64_t num_samples_per_node,
    const int64_t k) {
  assert((src->ndim == 1) && (dst->ndim == 1)
          && (src->shape[0] % num_samples_per_node == 0)
          && (src->shape[0] == dst->shape[0]));
  std::tuple<IdArray, IdArray, IdArray> result;

132
  ATEN_XPU_SWITCH_CUDA((src->ctx).device_type, XPU, "SelectPinSageNeighbors", {
133
134
135
136
137
138
139
140
    ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, {
      result = impl::SelectPinSageNeighbors<XPU, IdxType>(src, dst, num_samples_per_node, k);
    });
  });

  return result;
}

};  // namespace sampling

DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    // Positional arguments: graph, seeds, metapath, list of probability arrays.
    HeteroGraphRef graph = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob_list = args[3];

    const auto prob_vec = ListValueToVector<FloatArray>(prob_list);

    IdArray traces, eids;
    TypeArray vtypes;
    std::tie(traces, eids, vtypes) =
        sampling::RandomWalk(graph.sptr(), seeds, metapath, prob_vec);

    // Hand the three arrays back as a list, preserving tuple order.
    List<Value> ret;
    ret.push_back(Value(MakeValue(traces)));
    ret.push_back(Value(MakeValue(eids)));
    ret.push_back(Value(MakeValue(vtypes)));
    *rv = ret;
  });

DGL_REGISTER_GLOBAL("sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    // Positional arguments: src, dst, number of traversals per node, k.
    IdArray src = args[0];
    IdArray dst = args[1];
    const int64_t num_traversals = static_cast<int64_t>(args[2]);
    const int64_t k = static_cast<int64_t>(args[3]);

    IdArray first, second, third;
    std::tie(first, second, third) =
        sampling::SelectPinSageNeighbors(src, dst, num_traversals, k);

    // Hand the three arrays back as a list, preserving tuple order.
    List<Value> ret;
    ret.push_back(Value(MakeValue(first)));
    ret.push_back(Value(MakeValue(second)));
    ret.push_back(Value(MakeValue(third)));
    *rv = ret;
  });

DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    // Positional arguments: graph, seeds, metapath, probability arrays,
    // scalar restart probability.
    HeteroGraphRef graph = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob_list = args[3];
    double restart_prob = args[4];

    const auto prob_vec = ListValueToVector<FloatArray>(prob_list);

    IdArray traces, eids;
    TypeArray vtypes;
    std::tie(traces, eids, vtypes) = sampling::RandomWalkWithRestart(
        graph.sptr(), seeds, metapath, prob_vec, restart_prob);

    // Hand the three arrays back as a list, preserving tuple order.
    List<Value> ret;
    ret.push_back(Value(MakeValue(traces)));
    ret.push_back(Value(MakeValue(eids)));
    ret.push_back(Value(MakeValue(vtypes)));
    *rv = ret;
  });

DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwiseRestart")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    // Positional arguments: graph, seeds, metapath, probability arrays,
    // array of restart probabilities.
    HeteroGraphRef graph = args[0];
    IdArray seeds = args[1];
    TypeArray metapath = args[2];
    List<Value> prob_list = args[3];
    FloatArray restart_prob = args[4];

    const auto prob_vec = ListValueToVector<FloatArray>(prob_list);

    IdArray traces, eids;
    TypeArray vtypes;
    std::tie(traces, eids, vtypes) = sampling::RandomWalkWithStepwiseRestart(
        graph.sptr(), seeds, metapath, prob_vec, restart_prob);

    // Hand the three arrays back as a list, preserving tuple order.
    List<Value> ret;
    ret.push_back(Value(MakeValue(traces)));
    ret.push_back(Value(MakeValue(eids)));
    ret.push_back(Value(MakeValue(vtypes)));
    *rv = ret;
  });

DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingPackTraces")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    IdArray traces = args[0];
    TypeArray trace_types = args[1];

    // Pack with -1 as the padding value (presumably dropping -1 entries);
    // it yields the packed array plus per-row lengths and offsets.
    IdArray packed_vids, lengths, offsets;
    std::tie(packed_vids, lengths, offsets) = Pack(traces, -1);
    // Slice the type rows with the same lengths; the second result is unused.
    IdArray packed_vtypes = std::get<0>(ConcatSlices(trace_types, lengths));

    List<Value> ret;
    ret.push_back(Value(MakeValue(packed_vids)));
    ret.push_back(Value(MakeValue(packed_vtypes)));
    ret.push_back(Value(MakeValue(lengths)));
    ret.push_back(Value(MakeValue(offsets)));
    *rv = ret;
  });

};  // namespace dgl