randomwalk_cpu.cc 4.42 KB
Newer Older
1
2
3
4
5
6
7
8
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalk_cpu.cc
 * \brief DGL sampler - CPU implementation of metapath-based random walk with OpenMP
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
9
#include <dgl/runtime/device_api.h>
10
#include <vector>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
11
#include <utility>
12
#include <algorithm>
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include "randomwalks_impl.h"
#include "randomwalks_cpu.h"
#include "metapath_randomwalk.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

template<DLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
27
std::pair<IdArray, IdArray> RandomWalk(
28
29
30
31
32
33
34
35
36
37
38
39
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob) {
  TerminatePredicate<IdxType> terminate =
    [] (IdxType *data, dgl_id_t curr, int64_t len) {
      return false;
    };

  return MetapathBasedRandomWalk<XPU, IdxType>(hg, seeds, metapath, prob, terminate);
}

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*!
 * \brief For each fixed-size group of entries in \a src, count how often every
 *        ID occurs and keep the most frequent ones.
 *
 * \a src is read as consecutive groups of \a num_samples_per_node entries;
 * trailing entries that do not fill a whole group are ignored.  Group i is
 * associated with the destination ID stored at
 * dst[i * num_samples_per_node].  Entries equal to -1 are not counted.
 * Within a group, candidates are ordered by descending count, with ties broken
 * by descending ID, and at most \a k of them are emitted.
 *
 * Unlike a plain in-place sort, this implementation works on a private copy of
 * each group, so the contents of \a src are left untouched.
 *
 * \tparam XPU Device type (the function requires CPU-resident arrays).
 * \tparam IdxType Integer type of the IDs stored in \a src and \a dst.
 * \param src Source ID array; must be on CPU.
 * \param dst Destination ID array; must be on CPU.
 * \param num_samples_per_node Size of each group; must be positive.
 * \param k Maximum number of entries kept per group.  A negative value is
 *          converted to a very large unsigned number, so every candidate is
 *          kept in that case.
 * \return A tuple (selected source IDs, matching destination IDs, occurrence
 *         counts); the three arrays have the same length.
 */
template<DLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
    const IdArray src,
    const IdArray dst,
    const int64_t num_samples_per_node,
    const int64_t k) {
  CHECK(src->ctx.device_type == kDLCPU) << "IdArray needs be on CPU!";
  // dst is dereferenced on the host below, so it needs the same guarantee.
  CHECK(dst->ctx.device_type == kDLCPU) << "IdArray needs be on CPU!";
  // Guards the division below and the "last element of the group" access.
  CHECK(num_samples_per_node > 0) << "num_samples_per_node must be positive!";

  const int64_t num_groups = src->shape[0] / num_samples_per_node;
  const IdxType* src_data = src.Ptr<IdxType>();
  const IdxType* dst_data = dst.Ptr<IdxType>();
  std::vector<IdxType> res_src_vec, res_dst_vec, res_cnt_vec;

  // Scratch buffers reused across groups to avoid per-iteration allocations.
  std::vector<IdxType> group;
  std::vector<std::pair<IdxType, IdxType>> vec;  // (count, id)

  for (int64_t i = 0; i < num_groups; ++i) {
    const int64_t start_idx = i * num_samples_per_node;
    const int64_t end_idx = start_idx + num_samples_per_node;
    const IdxType dst_node = dst_data[start_idx];

    // Sort a copy so equal IDs become adjacent without modifying the input.
    group.assign(src_data + start_idx, src_data + end_idx);
    std::sort(group.begin(), group.end());

    vec.clear();
    int64_t cnt = 0;
    for (int64_t j = 0; j < num_samples_per_node; ++j) {
      // A change of value closes the run of the previous ID.
      if ((j != 0) && (group[j] != group[j-1])) {
        if (group[j-1] != -1) {
          vec.emplace_back(static_cast<IdxType>(cnt), group[j-1]);
        }
        cnt = 0;
      }
      ++cnt;
    }
    // add last count
    if (group.back() != -1) {
      vec.emplace_back(static_cast<IdxType>(cnt), group.back());
    }

    // Descending by count, then by ID (lexicographic pair comparison).
    std::sort(vec.begin(), vec.end(),
              std::greater<std::pair<IdxType, IdxType>>());
    const int64_t num_kept = std::min(vec.size(), static_cast<size_t>(k));
    for (int64_t j = 0; j < num_kept; ++j) {
      const auto& pair_item = vec[j];
      res_src_vec.emplace_back(pair_item.second);
      res_dst_vec.emplace_back(dst_node);
      res_cnt_vec.emplace_back(pair_item.first);
    }
  }

  IdArray res_src = IdArray::Empty({static_cast<int64_t>(res_src_vec.size())},
      src->dtype, src->ctx);
  IdArray res_dst = IdArray::Empty({static_cast<int64_t>(res_dst_vec.size())},
      dst->dtype, dst->ctx);
  IdArray res_cnt = IdArray::Empty({static_cast<int64_t>(res_cnt_vec.size())},
      src->dtype, src->ctx);

  // copy data from vector to NDArray
  auto device = runtime::DeviceAPI::Get(src->ctx);
  device->CopyDataFromTo(static_cast<IdxType*>(res_src_vec.data()), 0,
      res_src.Ptr<IdxType>(), 0,
      sizeof(IdxType) * res_src_vec.size(),
      DGLContext{kDLCPU, 0}, res_src->ctx,
      res_src->dtype, 0);
  device->CopyDataFromTo(static_cast<IdxType*>(res_dst_vec.data()), 0,
      res_dst.Ptr<IdxType>(), 0,
      sizeof(IdxType) * res_dst_vec.size(),
      DGLContext{kDLCPU, 0}, res_dst->ctx,
      res_dst->dtype, 0);
  device->CopyDataFromTo(static_cast<IdxType*>(res_cnt_vec.data()), 0,
      res_cnt.Ptr<IdxType>(), 0,
      sizeof(IdxType) * res_cnt_vec.size(),
      DGLContext{kDLCPU, 0}, res_cnt->ctx,
      res_cnt->dtype, 0);

  return std::make_tuple(res_src, res_dst, res_cnt);
}

109
template
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
110
std::pair<IdArray, IdArray> RandomWalk<kDLCPU, int32_t>(
111
112
113
114
115
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob);
template
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
116
std::pair<IdArray, IdArray> RandomWalk<kDLCPU, int64_t>(
117
118
119
120
121
    const HeteroGraphPtr hg,
    const IdArray seeds,
    const TypeArray metapath,
    const std::vector<FloatArray> &prob);

122
123
124
125
126
127
128
129
130
131
132
133
134
// Explicit instantiation of SelectPinSageNeighbors for XPU = kDLCPU,
// IdxType = int32_t.
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDLCPU, int32_t>(
    const IdArray src, const IdArray dst,
    const int64_t num_samples_per_node, const int64_t k);
// Explicit instantiation of SelectPinSageNeighbors for XPU = kDLCPU,
// IdxType = int64_t.
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDLCPU, int64_t>(
    const IdArray src, const IdArray dst,
    const int64_t num_samples_per_node, const int64_t k);

135
136
137
138
139
};  // namespace impl

};  // namespace sampling

};  // namespace dgl