random.cc 1.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
 *  Copyright (c) 2023 by Contributors
 * @file random.cc
 * @brief Random Engine.
 */

#include "./random.h"

#include <torch/torch.h>

#include <atomic>
#include <cstdint>
#include <mutex>
#include <optional>
#include <random>

namespace graphbolt {

namespace {

/**
 * @brief Get a small, unique integer ID for the calling thread.
 *
 * IDs are handed out in the order in which threads first call this function,
 * starting at 0, and stay fixed for the rest of that thread's lifetime. They
 * are used below as the stream selector passed to SetSeed(seed, stream).
 */
inline uint32_t GetThreadId() {
  // Process-wide counter; fetch_add returns each value exactly once, so no
  // mutex is needed to hand out IDs.
  static std::atomic<uint32_t> next_id{0};
  // Function-local thread_local: initialized once per thread, on that thread's
  // first call, and reused on every later call from the same thread.
  static thread_local const uint32_t id = next_id.fetch_add(1);
  return id;
}

}  // namespace

// Process-wide seed override written by SetManualSeed(). While empty, the
// default constructor draws its seed from std::random_device instead.
std::optional<uint64_t> RandomEngine::manual_seed{std::nullopt};

/**
 * @brief Constructor with default seed.
 *
 * Uses the seed registered via SetManualSeed() when one is present; otherwise
 * draws a fresh seed from std::random_device. The stream is then chosen from
 * the calling thread's ID by SetSeed(uint64_t).
 */
RandomEngine::RandomEngine() {
  uint64_t seed;
  if (manual_seed.has_value()) {
    // Deterministic path: do not touch std::random_device at all.
    seed = *manual_seed;
  } else {
    // std::random_device yields `unsigned int` values (32 bits on common
    // platforms); combine two draws so the whole 64-bit seed space is used.
    std::random_device rd;
    const uint64_t high = rd();
    const uint64_t low = rd();
    seed = (high << 32) | low;
  }
  SetSeed(seed);
}

/**
 * @brief Constructor with given seed.
 *
 * Delegates to RandomEngine(seed, stream), using the calling thread's ID as
 * the stream so that engines given the same seed on different threads select
 * different streams.
 *
 * The previous body, `{ RandomEngine(seed, GetThreadId()); }`, only built and
 * discarded a temporary, which left this object unseeded. A delegating
 * constructor seeds `*this`.
 */
RandomEngine::RandomEngine(uint64_t seed)
    : RandomEngine(seed, GetThreadId()) {}

/**
 * @brief Constructor with an explicit seed and stream.
 * @param seed Seed value forwarded to SetSeed().
 * @param stream Stream selector forwarded to SetSeed().
 */
RandomEngine::RandomEngine(uint64_t seed, uint64_t stream) {
  this->SetSeed(seed, stream);
}

/**
 * @brief Get the thread-local random number generator instance.
 * @return Pointer obtained from dmlc::ThreadLocalStore<RandomEngine>::Get().
 */
RandomEngine* RandomEngine::ThreadLocal() {
  using EngineStore = dmlc::ThreadLocalStore<RandomEngine>;
  return EngineStore::Get();
}

/** @brief Set the seed. */
void RandomEngine::SetSeed(uint64_t seed) { SetSeed(seed, GetThreadId()); }

/** @brief Set the seed. */
void RandomEngine::SetSeed(uint64_t seed, uint64_t stream) {
  rng_.seed(seed, stream);
}

/**
 * @brief Manually fix the seed used by subsequently default-constructed
 * engines.
 *
 * The signed value is stored as uint64_t, so negative seeds wrap modulo 2^64.
 * NOTE(review): manual_seed is written here and read by the default
 * constructor without synchronization. Confirm callers never invoke this
 * concurrently with engine construction.
 *
 * @param seed Seed value to store in manual_seed.
 */
void RandomEngine::SetManualSeed(int64_t seed) {
  manual_seed = static_cast<uint64_t>(seed);
}

}  // namespace graphbolt