"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "17612de451244c4c169b1498f1333401dfd3106f"
array_utils.h 3.94 KB
Newer Older
1
/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cpu/array_utils.h
 * @brief Utility classes and functions for DGL arrays.
 */
#ifndef DGL_ARRAY_CPU_ARRAY_UTILS_H_
#define DGL_ARRAY_CPU_ARRAY_UTILS_H_

9
#include <dgl/aten/types.h>
10
#include <dgl/runtime/parallel_for.h>
11
#include <parallel_hashmap/phmap.h>
12

13
14
#include <unordered_map>
#include <utility>
15
16
#include <vector>

17
18
#include "../../c_api_common.h"

19
20
21
namespace dgl {
namespace aten {

22
/**
23
 * @brief A hashmap that maps each ids in the given array to new ids starting
24
 * from zero.
25
26
27
28
29
30
31
32
33
 *
 * Useful for relabeling integers and finding unique integers.
 *
 * Usually faster than std::unordered_map in existence checking.
 */
template <typename IdType>
class IdHashMap {
 public:
  // default ctor
34
  IdHashMap() : filter_(kFilterSize, false) {}
35
36
37

  // Construct the hashmap using the given id array.
  // The id array could contain duplicates.
38
39
40
  // If the id array has no duplicates, the array will be relabeled to
  // consecutive integers starting from 0.
  explicit IdHashMap(IdArray ids) : filter_(kFilterSize, false) {
41
    oldv2newv_.reserve(ids->shape[0]);
42
43
44
    Update(ids);
  }

45
  // copy ctor
46
  IdHashMap(const IdHashMap& other) = default;
47

48
  void Reserve(const int64_t size) { oldv2newv_.reserve(size); }
49

50
51
52
53
54
55
56
  // Update the hashmap with given id array.
  // The id array could contain duplicates.
  void Update(IdArray ids) {
    const IdType* ids_data = static_cast<IdType*>(ids->data);
    const int64_t len = ids->shape[0];
    for (int64_t i = 0; i < len; ++i) {
      const IdType id = ids_data[i];
57
58
      // phmap::flat_hash_map::insert assures that an insertion will not happen
      // if the key already exists.
59
60
      oldv2newv_.insert({id, oldv2newv_.size()});
      filter_[id & kFilterMask] = true;
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    }
  }

  // Return true if the given id is contained in this hashmap.
  bool Contains(IdType id) const {
    return filter_[id & kFilterMask] && oldv2newv_.count(id);
  }

  // Return the new id of the given id. If the given id is not contained
  // in the hash map, returns the default_val instead.
  IdType Map(IdType id, IdType default_val) const {
    if (filter_[id & kFilterMask]) {
      auto it = oldv2newv_.find(id);
      return (it == oldv2newv_.end()) ? default_val : it->second;
    } else {
      return default_val;
    }
  }

  // Return the new id of each id in the given array.
  IdArray Map(IdArray ids, IdType default_val) const {
    const IdType* ids_data = static_cast<IdType*>(ids->data);
    const int64_t len = ids->shape[0];
    IdArray values = NewIdArray(len, ids->ctx, ids->dtype.bits);
85
86
87
88
89
90
91
    IdType* values_data = values.Ptr<IdType>();
    runtime::parallel_for(
        0, len, 1000, [=] (size_t begin, size_t end) {
          for (size_t i = begin; i < end; ++i) {
            values_data[i] = Map(ids_data[i], default_val);
          }
        });
92
93
94
95
96
    return values;
  }

  // Return all the old ids collected so far, ordered by new id.
  IdArray Values() const {
97
98
    IdArray values = NewIdArray(
        oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
99
    IdType* values_data = static_cast<IdType*>(values->data);
100
    for (auto pair : oldv2newv_) values_data[pair.second] = pair.first;
101
102
103
    return values;
  }

104
  inline size_t Size() const { return oldv2newv_.size(); }
105

106
107
108
109
 private:
  static constexpr int32_t kFilterMask = 0xFFFFFF;
  static constexpr int32_t kFilterSize = kFilterMask + 1;
  // This bitmap is used as a bloom filter to remove some lookups.
110
111
  // Hashtable is very slow. Using bloom filter can significantly speed up
  // lookups.
112
113
  std::vector<bool> filter_;
  // The hashmap from old vid to new vid
114
  phmap::flat_hash_map<IdType, IdType> oldv2newv_;
115
116
};

117
/**
 * @brief Hash type for building maps/sets with pairs as keys.
 *
 * Mixes the std::hash values of both members with the boost::hash_combine
 * step. A plain XOR of the two member hashes would be symmetric, so (a, b)
 * and (b, a) would collide. It would also send every pair (a, a) to 0,
 * degrading hash tables keyed on such pairs.
 */
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& pair) const {
    std::size_t seed = std::hash<T1>()(pair.first);
    // 0x9e3779b9 is the golden-ratio constant used by boost::hash_combine; the
    // shifts spread the running seed's bits so member order matters.
    seed ^= std::hash<T2>()(pair.second) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2);
    return seed;
  }
};

127
128
}  // namespace aten
}  // namespace dgl
129
130

#endif  // DGL_ARRAY_CPU_ARRAY_UTILS_H_