frequency_hashmap.cuh 2.25 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2021 by Contributors
 * \file graph/sampling/randomwalks/frequency_hashmap.cuh
 * \brief frequency hashmap - used to select the top-k most frequent edges of
 *        each node
 */

7
8
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
9
10
11

#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
12

13
#include <tuple>
14
15
16
17
18

namespace dgl {
namespace sampling {
namespace impl {

19
template <typename IdxType>
class DeviceEdgeHashmap {
 public:
  // A single hash-table entry: a source node id paired with a counter.
  struct EdgeItem {
    IdxType src;
    IdxType cnt;
  };

  DeviceEdgeHashmap() = delete;

  /*!
   * \brief Build a view over caller-provided buffers.
   *
   * Only the two sizes and the two pointers are stored: nothing is allocated
   * here, and since no destructor is declared, nothing is freed either.
   * \param dst_count number of destination nodes.
   * \param slots_per_dst number of items per destination; also the modulus
   *        used by EdgeHash.
   * \param unique_edges buffer stored as _dst_unique_edges.
   * \param table buffer of EdgeItem entries stored as _edge_hashmap.
   */
  DeviceEdgeHashmap(
      int64_t dst_count, int64_t slots_per_dst, IdxType *unique_edges,
      EdgeItem *table)
      : _num_dst{dst_count},
        _num_items_each_dst{slots_per_dst},
        _dst_unique_edges{unique_edges},
        _edge_hashmap{table} {}

  /*!
   * \brief Insert edge (src, dst_idx).
   * \return the old cnt of this edge.
   */
  inline __device__ IdxType InsertEdge(
      const IdxType &src, const IdxType &dst_idx);

  // Presumably the number of distinct edges recorded for `dst_idx`; the
  // definition is out of line -- NOTE(review): confirm there.
  inline __device__ IdxType GetDstCount(const IdxType &dst_idx);

  // Presumably the count stored for edge (src, dst_idx); defined out of
  // line -- NOTE(review): confirm there.
  inline __device__ IdxType GetEdgeCount(
      const IdxType &src, const IdxType &dst_idx);

 private:
  // Reduces `id` modulo the per-destination item count. The remainder is
  // computed in the promoted (64-bit) type and then narrowed back to IdxType;
  // for non-negative ids it lies in [0, _num_items_each_dst).
  inline __device__ IdxType EdgeHash(const IdxType &id) const {
    const auto slot = id % _num_items_each_dst;
    return static_cast<IdxType>(slot);
  }

  int64_t _num_dst;
  int64_t _num_items_each_dst;
  IdxType *_dst_unique_edges;
  EdgeItem *_edge_hashmap;
};

52
template <typename IdxType>
class FrequencyHashmap {
 public:
  // Default value of the constructor's `edge_table_scale` argument.
  static constexpr int64_t kDefaultEdgeTableScale = 3;

  // Entry type of the edge table, shared with the device-side view.
  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;

  FrequencyHashmap() = delete;

  /*!
   * \brief Construct a hashmap for `num_dst` destination nodes.
   * \param num_dst number of destination nodes.
   * \param num_items_each_dst per-destination item count; the exact table
   *        sizing (and how `edge_table_scale` enters it) is in the
   *        out-of-line definition -- NOTE(review): confirm there.
   * \param ctx device context stored in `_ctx`.
   * \param stream CUDA stream stored in `_stream`; presumably the stream the
   *        hashmap's work is issued on -- confirm in the definition.
   * \param edge_table_scale scaling factor for the edge table; defaults to
   *        kDefaultEdgeTableScale.
   */
  FrequencyHashmap(
      int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,
      cudaStream_t stream, int64_t edge_table_scale = kDefaultEdgeTableScale);

  ~FrequencyHashmap();

  // This class holds raw pointers and declares its own destructor (defined
  // out of line, presumably releasing those buffers -- confirm). The
  // implicitly generated copy operations would duplicate the pointers and
  // let two objects release the same buffers, so copying is disabled.
  FrequencyHashmap(const FrequencyHashmap &) = delete;
  FrequencyHashmap &operator=(const FrequencyHashmap &) = delete;

  /*!
   * \brief Select, for each node, the edges that occur most frequently
   *        (top-k frequency selection).
   * \param src_data source ids of the `num_edges` input edges.
   * \param dst_data destination ids of the `num_edges` input edges.
   * \param dtype data type for the returned arrays -- NOTE(review): confirm.
   * \param num_edges number of input edges.
   * \param num_edges_per_node input edges per node; assumed uniform --
   *        confirm against callers.
   * \param num_pick presumably k, the number of edges kept per node.
   * \return three IdArrays; presumably (src, dst, frequency) of the picked
   *         edges -- TODO confirm against the out-of-line definition.
   */
  std::tuple<IdArray, IdArray, IdArray> Topk(
      const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
      const int64_t num_edges, const int64_t num_edges_per_node,
      const int64_t num_pick);

 private:
  DGLContext _ctx;
  cudaStream_t _stream;
  // Device-side view handed to kernels; built from the two buffers below.
  DeviceEdgeHashmap<IdxType> *_device_edge_hashmap;
  IdxType *_dst_unique_edges;
  EdgeItem *_edge_hashmap;
};

};  // namespace impl
};  // namespace sampling
};  // namespace dgl

79
#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_