frequency_hashmap.cuh 2.33 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
3
/**
 *  Copyright (c) 2021 by Contributors
 * @file graph/sampling/frequency_hashmap.cuh
 * @brief frequency hashmap - used to select top-k frequency edges of each node
 */

9
10
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_
11
12
13

#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
14

15
#include <tuple>
16
17
18
19
20

namespace dgl {
namespace sampling {
namespace impl {

21
template <typename IdxType>
22
class DeviceEdgeHashmap {
23
 public:
24
25
26
27
28
  struct EdgeItem {
    IdxType src;
    IdxType cnt;
  };
  DeviceEdgeHashmap() = delete;
29
30
31
32
33
34
35
  DeviceEdgeHashmap(
      int64_t num_dst, int64_t num_items_each_dst, IdxType *dst_unique_edges,
      EdgeItem *edge_hashmap)
      : _num_dst(num_dst),
        _num_items_each_dst(num_items_each_dst),
        _dst_unique_edges(dst_unique_edges),
        _edge_hashmap(edge_hashmap) {}
36
  // return the old cnt of this edge
37
38
  inline __device__ IdxType
  InsertEdge(const IdxType &src, const IdxType &dst_idx);
39
  inline __device__ IdxType GetDstCount(const IdxType &dst_idx);
40
41
  inline __device__ IdxType
  GetEdgeCount(const IdxType &src, const IdxType &dst_idx);
42

43
 private:
44
45
  int64_t _num_dst;
  int64_t _num_items_each_dst;
46
  IdxType *_dst_unique_edges;
47
48
49
50
  EdgeItem *_edge_hashmap;

  inline __device__ IdxType EdgeHash(const IdxType &id) const {
    return id % _num_items_each_dst;
51
  }
52
53
};

54
template <typename IdxType>
55
class FrequencyHashmap {
56
 public:
57
58
  static constexpr int64_t kDefaultEdgeTableScale = 3;
  FrequencyHashmap() = delete;
59
60
  FrequencyHashmap(
      int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,
sangwzh's avatar
sangwzh committed
61
      hipStream_t stream, int64_t edge_table_scale = kDefaultEdgeTableScale);
62
63
64
  ~FrequencyHashmap();
  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
  std::tuple<IdArray, IdArray, IdArray> Topk(
65
      const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
66
67
      const int64_t num_edges, const int64_t num_edges_per_node,
      const int64_t num_pick);
68

69
 private:
70
  DGLContext _ctx;
sangwzh's avatar
sangwzh committed
71
  hipStream_t _stream;
72
  DeviceEdgeHashmap<IdxType> *_device_edge_hashmap;
73
  IdxType *_dst_unique_edges;
74
75
76
77
78
79
80
  EdgeItem *_edge_hashmap;
};

};  // namespace impl
};  // namespace sampling
};  // namespace dgl

81
#endif  // DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_HASHMAP_CUH_