/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cpu/array_utils.h
 * @brief Utility classes and functions for DGL arrays.
 */
#ifndef DGL_ARRAY_CPU_ARRAY_UTILS_H_
#define DGL_ARRAY_CPU_ARRAY_UTILS_H_

#include <dgl/aten/types.h>
#include <parallel_hashmap/phmap.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../c_api_common.h"

namespace dgl {
namespace aten {

21
/**
22
 * @brief A hashmap that maps each ids in the given array to new ids starting
23
 * from zero.
24
25
26
27
28
29
30
31
32
 *
 * Useful for relabeling integers and finding unique integers.
 *
 * Usually faster than std::unordered_map in existence checking.
 */
template <typename IdType>
class IdHashMap {
 public:
  // default ctor
33
  IdHashMap() : filter_(kFilterSize, false) {}
34
35
36

  // Construct the hashmap using the given id array.
  // The id array could contain duplicates.
37
38
39
  // If the id array has no duplicates, the array will be relabeled to
  // consecutive integers starting from 0.
  explicit IdHashMap(IdArray ids) : filter_(kFilterSize, false) {
40
    oldv2newv_.reserve(ids->shape[0]);
41
42
43
    Update(ids);
  }

44
  // copy ctor
45
  IdHashMap(const IdHashMap& other) = default;
46

47
  void Reserve(const int64_t size) { oldv2newv_.reserve(size); }
48

49
50
51
52
53
54
55
  // Update the hashmap with given id array.
  // The id array could contain duplicates.
  void Update(IdArray ids) {
    const IdType* ids_data = static_cast<IdType*>(ids->data);
    const int64_t len = ids->shape[0];
    for (int64_t i = 0; i < len; ++i) {
      const IdType id = ids_data[i];
56
57
      // phmap::flat_hash_map::insert assures that an insertion will not happen
      // if the key already exists.
58
59
      oldv2newv_.insert({id, oldv2newv_.size()});
      filter_[id & kFilterMask] = true;
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
    }
  }

  // Return true if the given id is contained in this hashmap.
  bool Contains(IdType id) const {
    return filter_[id & kFilterMask] && oldv2newv_.count(id);
  }

  // Return the new id of the given id. If the given id is not contained
  // in the hash map, returns the default_val instead.
  IdType Map(IdType id, IdType default_val) const {
    if (filter_[id & kFilterMask]) {
      auto it = oldv2newv_.find(id);
      return (it == oldv2newv_.end()) ? default_val : it->second;
    } else {
      return default_val;
    }
  }

  // Return the new id of each id in the given array.
  IdArray Map(IdArray ids, IdType default_val) const {
    const IdType* ids_data = static_cast<IdType*>(ids->data);
    const int64_t len = ids->shape[0];
    IdArray values = NewIdArray(len, ids->ctx, ids->dtype.bits);
    IdType* values_data = static_cast<IdType*>(values->data);
    for (int64_t i = 0; i < len; ++i)
      values_data[i] = Map(ids_data[i], default_val);
    return values;
  }

  // Return all the old ids collected so far, ordered by new id.
  IdArray Values() const {
92
93
    IdArray values = NewIdArray(
        oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
94
    IdType* values_data = static_cast<IdType*>(values->data);
95
    for (auto pair : oldv2newv_) values_data[pair.second] = pair.first;
96
97
98
    return values;
  }

99
  inline size_t Size() const { return oldv2newv_.size(); }
100

101
102
103
104
 private:
  static constexpr int32_t kFilterMask = 0xFFFFFF;
  static constexpr int32_t kFilterSize = kFilterMask + 1;
  // This bitmap is used as a bloom filter to remove some lookups.
105
106
  // Hashtable is very slow. Using bloom filter can significantly speed up
  // lookups.
107
108
  std::vector<bool> filter_;
  // The hashmap from old vid to new vid
109
  phmap::flat_hash_map<IdType, IdType> oldv2newv_;
110
111
};

112
/**
 * @brief Hash type for building maps/sets with pairs as keys.
 *
 * Combines the std::hash of both members using the boost::hash_combine mixing
 * step. Unlike a plain XOR of the two member hashes, this is order-sensitive,
 * so (a, b) and (b, a) generally hash differently. It also means pairs with
 * equal members (a, a) no longer all collapse to the same hash value of 0.
 */
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& pair) const {
    std::size_t seed = std::hash<T1>()(pair.first);
    // 0x9e3779b9 is the 32-bit golden-ratio constant used by
    // boost::hash_combine. The shifts mix the current seed into the result so
    // that the combination depends on argument order.
    seed ^= std::hash<T2>()(pair.second) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2);
    return seed;
  }
};

}  // namespace aten
}  // namespace dgl
#endif  // DGL_ARRAY_CPU_ARRAY_UTILS_H_