/*!
 *  Copyright (c) 2021 by Contributors
 * \file graph/sampling/randomwalks/frequency_hashmap.cuh
 * \brief Frequency hashmap used to select the top-k most frequent edges of each node.
 */

#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include <tuple>

namespace dgl {
namespace sampling {
namespace impl {

/*!
 * \brief Device-side view of the edge-frequency table used by FrequencyHashmap.
 *
 * Only stores raw pointers handed to the constructor. This class declares no
 * destructor and never frees them, so the buffers are owned by whoever created
 * them. All member functions except the constructor are __device__.
 */
template<typename IdxType>
class DeviceEdgeHashmap {
 public:
  // One entry of the edge table: a source node id and its counter.
  struct EdgeItem {
    IdxType src;
    IdxType cnt;
  };

  DeviceEdgeHashmap() = delete;

  /*!
   * \param num_dst Number of destination nodes.
   * \param num_items_each_dst Table slots per destination (presumably; confirm
   *        against the allocation in FrequencyHashmap). It is also the modulus
   *        used by EdgeHash, so it must be > 0.
   * \param dst_unique_edges Buffer presumably holding one counter per
   *        destination. Not owned.
   * \param edge_hashmap Edge table buffer. Not owned.
   */
  DeviceEdgeHashmap(int64_t num_dst, int64_t num_items_each_dst,
      IdxType* dst_unique_edges, EdgeItem *edge_hashmap):
    _num_dst(num_dst), _num_items_each_dst(num_items_each_dst),
    _dst_unique_edges(dst_unique_edges), _edge_hashmap(edge_hashmap) {}

  // Records edge (src, dst_idx); returns the old cnt of this edge.
  // Defined out of line.
  inline __device__ IdxType InsertEdge(const IdxType &src, const IdxType &dst_idx);
  // NOTE(review): defined out of line. Presumably returns the number of unique
  // edges recorded for dst_idx; confirm against the definition.
  inline __device__ IdxType GetDstCount(const IdxType &dst_idx);
  // NOTE(review): defined out of line. Presumably returns the current count of
  // edge (src, dst_idx); confirm against the definition.
  inline __device__ IdxType GetEdgeCount(const IdxType &src, const IdxType &dst_idx);

 private:
  int64_t _num_dst;
  int64_t _num_items_each_dst;
  IdxType  *_dst_unique_edges;
  EdgeItem *_edge_hashmap;

  // Maps an id to a slot offset in [0, _num_items_each_dst) for non-negative
  // ids. Requires _num_items_each_dst > 0.
  inline __device__ IdxType EdgeHash(const IdxType &id) const {
    return id % _num_items_each_dst;
  }
};

/*!
 * \brief Host-side manager of the edge-frequency table.
 *
 * Holds raw pointers to the table buffers plus a pointer to a
 * DeviceEdgeHashmap view of them. It declares a destructor, defined out of
 * line, which presumably releases those buffers; confirm against the
 * definition.
 *
 * Copying is disabled. A copy would share the same raw pointers, so two
 * destructors would release them twice.
 */
template<typename IdxType>
class FrequencyHashmap {
 public:
  // Default value of edge_table_scale. How it sizes the edge table is decided
  // in the out-of-line constructor. NOTE(review): confirm there.
  static constexpr int64_t kDefaultEdgeTableScale = 3;
  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;

  FrequencyHashmap() = delete;
  /*!
   * \param num_dst Number of destination nodes.
   * \param num_items_each_dst Number of items tracked per destination.
   * \param ctx Device context the buffers belong to.
   * \param stream CUDA stream the work is issued on.
   * \param edge_table_scale Scale factor for the edge table size
   *        (defaults to kDefaultEdgeTableScale).
   */
  FrequencyHashmap(int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx, cudaStream_t stream,
      int64_t edge_table_scale = kDefaultEdgeTableScale);
  ~FrequencyHashmap();

  // Non-copyable: the class holds raw buffer pointers and a user-declared
  // destructor (rule of three).
  FrequencyHashmap(const FrequencyHashmap &) = delete;
  FrequencyHashmap &operator=(const FrequencyHashmap &) = delete;

  /*!
   * \brief Selects up to num_pick most frequent edges per node.
   *
   * \param src_data Source ids of the num_edges input edges.
   * \param dst_data Destination ids of the num_edges input edges.
   * \param dtype Id data type of the returned arrays.
   * \param num_edges Total number of input edges.
   * \param num_edges_per_node Number of input edges belonging to each node.
   * \param num_pick Maximum number of edges kept per node.
   * \return Three IdArrays. NOTE(review): presumably (src, dst, frequency) of
   *         the picked edges; confirm against the definition.
   */
  std::tuple<IdArray, IdArray, IdArray> Topk(
      const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
      const int64_t num_edges, const int64_t num_edges_per_node,
      const int64_t num_pick);

 private:
  DGLContext _ctx;
  cudaStream_t _stream;
  DeviceEdgeHashmap<IdxType> *_device_edge_hashmap;
  IdxType  *_dst_unique_edges;
  EdgeItem *_edge_hashmap;
};

};  // namespace impl
};  // namespace sampling
};  // namespace dgl

#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_