/**
 *  Copyright (c) 2018 by Contributors
 * @file graph/sampling/randomwalks.cc
 * @brief Dispatcher of different DGL random walks by device type
 */

#include <dgl/array.h>
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
10
#include <dgl/sampling/randomwalks.h>
11

12
#include <tuple>
13
#include <utility>
14
#include <vector>
15

16
#include "../../../c_api_common.h"
17
18
19
20
21
22
23
24
25
26
27
28
#include "randomwalks_impl.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

namespace sampling {

namespace {

void CheckRandomWalkInputs(
29
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
30
31
32
33
34
    const std::vector<FloatArray> &prob) {
  CHECK_INT(seeds, "seeds");
  CHECK_INT(metapath, "metapath");
  CHECK_NDIM(seeds, 1, "seeds");
  CHECK_NDIM(metapath, 1, "metapath");
35
36
37
38
  // (Xin): metapath is copied to GPU in CUDA random walk code
  // CHECK_SAME_CONTEXT(seeds, metapath);

  if (hg->IsPinned()) {
39
40
41
    CHECK_EQ(seeds->ctx.device_type, kDGLCUDA)
        << "Expected seeds (" << seeds->ctx << ")"
        << " to be on the GPU when the graph is pinned.";
42
  } else if (hg->Context() != seeds->ctx) {
43
44
45
    LOG(FATAL) << "Expected seeds (" << seeds->ctx << ")"
               << " to have the same "
               << "context as graph (" << hg->Context() << ").";
46
  }
47
48
  for (uint64_t i = 0; i < prob.size(); ++i) {
    FloatArray p = prob[i];
49
50
51
52
    CHECK_EQ(hg->Context(), p->ctx)
        << "Expected prob (" << p->ctx << ")"
        << " to have the same "
        << "context as graph (" << hg->Context() << ").";
53
    CHECK_FLOAT(p, "probability");
54
55
    if (p.GetSize() != 0) {
      CHECK_EQ(hg->IsPinned(), p.IsPinned())
56
          << "The prob array should have the same pinning status as the graph";
57
      CHECK_NDIM(p, 1, "probability");
58
    }
59
60
61
62
63
  }
}

};  // namespace

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
64
std::tuple<IdArray, IdArray, TypeArray> RandomWalk(
65
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
66
67
68
69
    const std::vector<FloatArray> &prob) {
  CheckRandomWalkInputs(hg, seeds, metapath, prob);

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
70
  std::pair<IdArray, IdArray> result;
71
  ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalk", {
72
73
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
74
      result = impl::RandomWalk<XPU, IdxType>(hg, seeds, metapath, prob);
75
76
77
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
78
  return std::make_tuple(result.first, result.second, vtypes);
79
80
}

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
81
std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(
82
83
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, double restart_prob) {
84
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
85
86
  CHECK(restart_prob >= 0 && restart_prob < 1)
      << "restart probability must belong to [0, 1)";
87
88

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
89
  std::pair<IdArray, IdArray> result;
90
  ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalkWithRestart", {
91
92
    ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
      vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
93
94
      result = impl::RandomWalkWithRestart<XPU, IdxType>(
          hg, seeds, metapath, prob, restart_prob);
95
96
97
    });
  });

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
98
  return std::make_tuple(result.first, result.second, vtypes);
99
100
}

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
101
std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(
102
103
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
    const std::vector<FloatArray> &prob, FloatArray restart_prob) {
104
105
106
107
  CheckRandomWalkInputs(hg, seeds, metapath, prob);
  // TODO(BarclayII): check the elements of restart probability

  TypeArray vtypes;
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
108
  std::pair<IdArray, IdArray> result;
109
110
111
112
113
114
115
116
  ATEN_XPU_SWITCH_CUDA(
      seeds->ctx.device_type, XPU, "RandomWalkWithStepwiseRestart", {
        ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
          vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
          result = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(
              hg, seeds, metapath, prob, restart_prob);
        });
      });
117

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
118
  return std::make_tuple(result.first, result.second, vtypes);
119
120
}

121
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
122
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
123
    const int64_t k) {
124
125
126
127
  assert(
      (src->ndim == 1) && (dst->ndim == 1) &&
      (src->shape[0] % num_samples_per_node == 0) &&
      (src->shape[0] == dst->shape[0]));
128
129
  std::tuple<IdArray, IdArray, IdArray> result;

130
  ATEN_XPU_SWITCH_CUDA((src->ctx).device_type, XPU, "SelectPinSageNeighbors", {
131
    ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, {
132
133
      result = impl::SelectPinSageNeighbors<XPU, IdxType>(
          src, dst, num_samples_per_node, k);
134
135
136
137
138
139
    });
  });

  return result;
}

140
141
142
};  // namespace sampling

// C API binding for sampling::RandomWalk.
// Arguments: [0] graph, [1] seeds, [2] metapath, [3] list of probability
// arrays. Returns a list holding the three arrays of the result tuple, in
// tuple order.
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      IdArray seeds = args[1];
      TypeArray metapath = args[2];
      List<Value> prob_list = args[3];

      const auto prob_arrays = ListValueToVector<FloatArray>(prob_list);

      IdArray out0, out1;
      TypeArray out2;
      std::tie(out0, out1, out2) = sampling::RandomWalk(
          graph_ref.sptr(), seeds, metapath, prob_arrays);

      List<Value> packed;
      packed.push_back(Value(MakeValue(out0)));
      packed.push_back(Value(MakeValue(out1)));
      packed.push_back(Value(MakeValue(out2)));
      *rv = packed;
    });
158
159

// C API binding for sampling::SelectPinSageNeighbors.
// Arguments: [0] src, [1] dst, [2] value forwarded as num_samples_per_node,
// [3] k. Returns a list holding the three arrays of the result tuple, in
// tuple order.
DGL_REGISTER_GLOBAL("sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      IdArray src = args[0];
      IdArray dst = args[1];
      int64_t num_traversals = static_cast<int64_t>(args[2]);
      int64_t k = static_cast<int64_t>(args[3]);

      IdArray out0, out1, out2;
      std::tie(out0, out1, out2) =
          sampling::SelectPinSageNeighbors(src, dst, num_traversals, k);

      List<Value> packed;
      packed.push_back(Value(MakeValue(out0)));
      packed.push_back(Value(MakeValue(out1)));
      packed.push_back(Value(MakeValue(out2)));
      *rv = packed;
    });
175

176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// C API binding for sampling::RandomWalkWithRestart.
// Arguments: [0] graph, [1] seeds, [2] metapath, [3] list of probability
// arrays, [4] restart probability (double). Returns a list holding the three
// arrays of the result tuple, in tuple order.
DGL_REGISTER_GLOBAL(
    "sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      IdArray seeds = args[1];
      TypeArray metapath = args[2];
      List<Value> prob_list = args[3];
      double restart_prob = args[4];

      const auto prob_arrays = ListValueToVector<FloatArray>(prob_list);

      IdArray out0, out1;
      TypeArray out2;
      std::tie(out0, out1, out2) = sampling::RandomWalkWithRestart(
          graph_ref.sptr(), seeds, metapath, prob_arrays, restart_prob);

      List<Value> packed;
      packed.push_back(Value(MakeValue(out0)));
      packed.push_back(Value(MakeValue(out1)));
      packed.push_back(Value(MakeValue(out2)));
      *rv = packed;
    });
195

196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
// C API binding for sampling::RandomWalkWithStepwiseRestart.
// Arguments: [0] graph, [1] seeds, [2] metapath, [3] list of probability
// arrays, [4] restart probability array. Returns a list holding the three
// arrays of the result tuple, in tuple order.
DGL_REGISTER_GLOBAL(
    "sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwiseRestart")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      HeteroGraphRef graph_ref = args[0];
      IdArray seeds = args[1];
      TypeArray metapath = args[2];
      List<Value> prob_list = args[3];
      FloatArray restart_prob = args[4];

      const auto prob_arrays = ListValueToVector<FloatArray>(prob_list);

      IdArray out0, out1;
      TypeArray out2;
      std::tie(out0, out1, out2) = sampling::RandomWalkWithStepwiseRestart(
          graph_ref.sptr(), seeds, metapath, prob_arrays, restart_prob);

      List<Value> packed;
      packed.push_back(Value(MakeValue(out0)));
      packed.push_back(Value(MakeValue(out1)));
      packed.push_back(Value(MakeValue(out2)));
      *rv = packed;
    });
215
216

// C API binding that packs walk traces.
// Arguments: [0] node ID array, [1] node type array.
// Calls Pack(vids, -1) to obtain (concatenated IDs, lengths, offsets), then
// ConcatSlices(vtypes, lengths) and keeps only its first output.
// NOTE(review): -1 is presumably the padding marker stripped by Pack --
// confirm against the Pack documentation.
// Returns the list [concatenated IDs, concatenated types, lengths, offsets].
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingPackTraces")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      IdArray trace_vids = args[0];
      TypeArray trace_vtypes = args[1];

      IdArray packed_vids, packed_vtypes, seg_lengths, seg_offsets;
      std::tie(packed_vids, seg_lengths, seg_offsets) = Pack(trace_vids, -1);
      // The second output of ConcatSlices is not needed here.
      std::tie(packed_vtypes, std::ignore) =
          ConcatSlices(trace_vtypes, seg_lengths);

      List<Value> packed;
      packed.push_back(Value(MakeValue(packed_vids)));
      packed.push_back(Value(MakeValue(packed_vtypes)));
      packed.push_back(Value(MakeValue(seg_lengths)));
      packed.push_back(Value(MakeValue(seg_offsets)));
      *rv = packed;
    });
232
233

};  // namespace dgl