Unverified Commit b1e2695f authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Feature] replace dgl PRNG with pcg32 (#4807)



* replace dgl PRNG with pcg32

* remove pcg submodule, add a simple implementation

* replace pcg32 with std::mt19937_64

* fix include order

* change RandomEngine to pcg32

* Remove custom pcg32 implementation, use the submodule provided by the original author.

* minor bug

* move include for linting

* include pcg for tests too
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent adb07d18
...@@ -335,6 +335,7 @@ if(BUILD_CPP_TEST) ...@@ -335,6 +335,7 @@ if(BUILD_CPP_TEST)
include_directories("third_party/dmlc-core/include") include_directories("third_party/dmlc-core/include")
include_directories("third_party/phmap") include_directories("third_party/phmap")
include_directories("third_party/libxsmm/include") include_directories("third_party/libxsmm/include")
include_directories("third_party/pcg/include")
file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc) file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc)
add_executable(runUnitTests ${TEST_SRC_FILES}) add_executable(runUnitTests ${TEST_SRC_FILES})
target_link_libraries(runUnitTests gtest gtest_main) target_link_libraries(runUnitTests gtest gtest_main)
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <pcg_random.hpp>
namespace dgl { namespace dgl {
namespace { namespace {
...@@ -47,7 +49,9 @@ class RandomEngine { ...@@ -47,7 +49,9 @@ class RandomEngine {
} }
/** @brief Constructor with given seed */ /** @brief Constructor with given seed */
explicit RandomEngine(uint32_t seed) { SetSeed(seed); } explicit RandomEngine(uint64_t seed, uint64_t stream = GetThreadId()) {
SetSeed(seed, stream);
}
/** @brief Get the thread-local random number generator instance */ /** @brief Get the thread-local random number generator instance */
static RandomEngine* ThreadLocal() { static RandomEngine* ThreadLocal() {
...@@ -57,9 +61,8 @@ class RandomEngine { ...@@ -57,9 +61,8 @@ class RandomEngine {
/** /**
* @brief Set the seed of this random number generator * @brief Set the seed of this random number generator
*/ */
void SetSeed(uint32_t seed) { void SetSeed(uint64_t seed, uint64_t stream = GetThreadId()) {
std::seed_seq seq{seed, GetThreadId()}; rng_.seed(seed, stream);
rng_.seed(seq);
} }
/** /**
...@@ -249,7 +252,7 @@ class RandomEngine { ...@@ -249,7 +252,7 @@ class RandomEngine {
} }
private: private:
std::default_random_engine rng_; pcg32 rng_;
}; };
}; // namespace dgl }; // namespace dgl
......
/*! /**
* Copyright (c) 2022, NVIDIA Corporation * Copyright (c) 2022, NVIDIA Corporation
* Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek) * Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* All rights reserved. * All rights reserved.
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
* *
* \file array/cpu/labor_pick.h * @file array/cpu/labor_pick.h
* \brief Template implementation for layerwise pick operators. * @brief Template implementation for layerwise pick operators.
*/ */
#ifndef DGL_ARRAY_CPU_LABOR_PICK_H_ #ifndef DGL_ARRAY_CPU_LABOR_PICK_H_
...@@ -37,7 +37,6 @@ ...@@ -37,7 +37,6 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <pcg_random.hpp>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment