randomwalk_cpu.cc 4.41 KB
Newer Older
1
2
3
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/sampling/randomwalk_cpu.cc
 * \brief DGL sampler - CPU implementation of metapath-based random walk with
 * OpenMP.
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
10
#include <dgl/runtime/device_api.h>
11

12
#include <algorithm>
13
14
15
#include <utility>
#include <vector>

16
#include "metapath_randomwalk.h"
17
18
#include "randomwalks_cpu.h"
#include "randomwalks_impl.h"
19
20
21
22
23
24
25
26
27
28

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace sampling {

namespace impl {

29
template <DGLDeviceType XPU, typename IdxType>
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
30
std::pair<IdArray, IdArray> RandomWalk(
31
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
32
    const std::vector<FloatArray> &prob) {
33
34
  TerminatePredicate<IdxType> terminate = [](IdxType *data, dgl_id_t curr,
                                             int64_t len) { return false; };
35

36
37
  return MetapathBasedRandomWalk<XPU, IdxType>(
      hg, seeds, metapath, prob, terminate);
38
39
}

40
template <DGLDeviceType XPU, typename IdxType>
41
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
42
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
43
    const int64_t k) {
44
  CHECK(src->ctx.device_type == kDGLCPU) << "IdArray needs be on CPU!";
45
  int64_t len = src->shape[0] / num_samples_per_node;
46
47
  IdxType *src_data = src.Ptr<IdxType>();
  const IdxType *dst_data = dst.Ptr<IdxType>();
48
49
50
51
52
53
54
55
56
  std::vector<IdxType> res_src_vec, res_dst_vec, res_cnt_vec;
  for (int64_t i = 0; i < len; ++i) {
    int64_t start_idx = (i * num_samples_per_node);
    int64_t end_idx = (start_idx + num_samples_per_node);
    IdxType dst_node = dst_data[start_idx];
    std::sort(src_data + start_idx, src_data + end_idx);
    int64_t cnt = 0;
    std::vector<std::pair<IdxType, IdxType>> vec;
    for (int64_t j = start_idx; j < end_idx; ++j) {
57
58
59
      if ((j != start_idx) && (src_data[j] != src_data[j - 1])) {
        if (src_data[j - 1] != -1) {
          vec.emplace_back(std::make_pair(cnt, src_data[j - 1]));
60
61
62
63
64
65
        }
        cnt = 0;
      }
      ++cnt;
    }
    // add last count
66
67
    if (src_data[end_idx - 1] != -1) {
      vec.emplace_back(std::make_pair(cnt, src_data[end_idx - 1]));
68
    }
69
70
    std::sort(
        vec.begin(), vec.end(), std::greater<std::pair<IdxType, IdxType>>());
71
72
73
74
75
76
77
78
    int64_t len = std::min(vec.size(), static_cast<size_t>(k));
    for (int64_t j = 0; j < len; ++j) {
      auto pair_item = vec[j];
      res_src_vec.emplace_back(pair_item.second);
      res_dst_vec.emplace_back(dst_node);
      res_cnt_vec.emplace_back(pair_item.first);
    }
  }
79
80
81
82
83
84
  IdArray res_src = IdArray::Empty(
      {static_cast<int64_t>(res_src_vec.size())}, src->dtype, src->ctx);
  IdArray res_dst = IdArray::Empty(
      {static_cast<int64_t>(res_dst_vec.size())}, dst->dtype, dst->ctx);
  IdArray res_cnt = IdArray::Empty(
      {static_cast<int64_t>(res_cnt_vec.size())}, src->dtype, src->ctx);
85
86
87

  // copy data from vector to NDArray
  auto device = runtime::DeviceAPI::Get(src->ctx);
88
89
90
91
92
93
94
95
96
97
98
99
  device->CopyDataFromTo(
      static_cast<IdxType *>(res_src_vec.data()), 0, res_src.Ptr<IdxType>(), 0,
      sizeof(IdxType) * res_src_vec.size(), DGLContext{kDGLCPU, 0},
      res_src->ctx, res_src->dtype);
  device->CopyDataFromTo(
      static_cast<IdxType *>(res_dst_vec.data()), 0, res_dst.Ptr<IdxType>(), 0,
      sizeof(IdxType) * res_dst_vec.size(), DGLContext{kDGLCPU, 0},
      res_dst->ctx, res_dst->dtype);
  device->CopyDataFromTo(
      static_cast<IdxType *>(res_cnt_vec.data()), 0, res_cnt.Ptr<IdxType>(), 0,
      sizeof(IdxType) * res_cnt_vec.size(), DGLContext{kDGLCPU, 0},
      res_cnt->ctx, res_cnt->dtype);
100
101
102
103

  return std::make_tuple(res_src, res_dst, res_cnt);
}

104
105
template std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int32_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
106
    const std::vector<FloatArray> &prob);
107
108
template std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int64_t>(
    const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
109
110
    const std::vector<FloatArray> &prob);

111
112
113
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCPU, int32_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
114
    const int64_t k);
115
116
117
template std::tuple<IdArray, IdArray, IdArray>
SelectPinSageNeighbors<kDGLCPU, int64_t>(
    const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
118
119
    const int64_t k);

120
121
122
123
124
};  // namespace impl

};  // namespace sampling

};  // namespace dgl