random.h 1.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file random.h
 * @brief Random Engine class.
 */
#ifndef GRAPHBOLT_RANDOM_H_
#define GRAPHBOLT_RANDOM_H_

#include <dmlc/thread_local.h>

#include <cstdint>
#include <mutex>
#include <pcg_random.hpp>
#include <random>
#include <thread>

namespace graphbolt {

/**
 * @brief Get a unique integer ID representing the calling thread.
 *
 * IDs are handed out sequentially, starting from 0, in the order in which
 * threads first call this function. Subsequent calls from the same thread
 * return the value cached in a thread_local variable.
 *
 * This function is intentionally NOT placed in an unnamed namespace: in a
 * header that would give it internal linkage, so every translation unit would
 * own a separate counter and a separate cached ID, and the IDs would no longer
 * be unique across the program. As a plain inline function its static
 * variables are shared by all translation units.
 *
 * @return The ID assigned to the calling thread.
 */
inline uint32_t GetThreadId() {
  static int num_threads = 0;
  static std::mutex mutex;
  // -1 marks "no ID assigned to this thread yet".
  static thread_local int id = -1;

  if (id == -1) {
    // The counter is shared between threads, so guard the read-and-increment.
    std::lock_guard<std::mutex> guard(mutex);
    id = num_threads++;
  }
  return static_cast<uint32_t>(id);
}

/**
 * @brief Thread-local Random Number Generator class.
 */
class RandomEngine {
 public:
  /** @brief Constructor with default seed. */
  RandomEngine() {
    std::random_device rd;
    SetSeed(rd());
  }

  /** @brief Constructor with given seed. */
  explicit RandomEngine(uint64_t seed, uint64_t stream = GetThreadId()) {
    SetSeed(seed, stream);
  }

  /** @brief Get the thread-local random number generator instance. */
  static RandomEngine* ThreadLocal() {
    return dmlc::ThreadLocalStore<RandomEngine>::Get();
  }

  /** @brief Set the seed. */
  void SetSeed(uint64_t seed, uint64_t stream = GetThreadId()) {
    rng_.seed(seed, stream);
  }

  /**
   * @brief Generate a uniform random integer in [low, high).
   */
  template <typename T>
  T RandInt(T lower, T upper) {
    std::uniform_int_distribution<T> dist(lower, upper - 1);
    return dist(rng_);
  }

 private:
  pcg32 rng_;
};
}  // namespace graphbolt

#endif  // GRAPHBOLT_RANDOM_H_