
/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file random.h
 * @brief Random Engine class.
 */
#ifndef GRAPHBOLT_RANDOM_H_
#define GRAPHBOLT_RANDOM_H_

#include <dmlc/thread_local.h>

#include <cstdint>
#include <mutex>
#include <optional>
#include <pcg_random.hpp>
#include <random>
#include <thread>

namespace graphbolt {

/**
 * @brief Thread-local Random Number Generator class.
 */
class RandomEngine {
 public:
  /** @brief Constructor with default seed. */
27
  RandomEngine();
28
29

  /** @brief Constructor with given seed. */
30
31
  explicit RandomEngine(uint64_t seed);
  explicit RandomEngine(uint64_t seed, uint64_t stream);
32
33

  /** @brief Get the thread-local random number generator instance. */
34
  static RandomEngine* ThreadLocal();
35
36

  /** @brief Set the seed. */
37
38
39
  void SetSeed(uint64_t seed);
  void SetSeed(uint64_t seed, uint64_t stream);

40
41
42
  /** @brief Protect manual seed accesses. */
  static std::mutex manual_seed_mutex;

43
44
45
  /** @brief Manually fix the seed. */
  static std::optional<uint64_t> manual_seed;
  static void SetManualSeed(int64_t seed);
46
47
48
49
50
51
52
53
54
55

  /**
   * @brief Generate a uniform random integer in [low, high).
   */
  template <typename T>
  T RandInt(T lower, T upper) {
    std::uniform_int_distribution<T> dist(lower, upper - 1);
    return dist(rng_);
  }

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  /**
   * @brief Generate a uniform random real number in [low, high).
   */
  template <typename T>
  T Uniform(T lower, T upper) {
    std::uniform_real_distribution<T> dist(lower, upper);
    return dist(rng_);
  }

  /**
   * @brief Generate random non-negative floating-point values according to
   * exponential distribution. Probability density function: P(x|λ) = λe^(-λx).
   */
  template <typename T>
  T Exponential(T lambda) {
    std::exponential_distribution<T> dist(lambda);
    return dist(rng_);
  }

75
76
77
 private:
  pcg32 rng_;
};
78

79
80
81
}  // namespace graphbolt

#endif  // GRAPHBOLT_RANDOM_H_