/*! * Copyright (c) 2017 by Contributors * \file random.cc * \brief Random number generator interfaces */ #include #include #include #include #include #include #ifdef DGL_USE_CUDA #include "../runtime/cuda/cuda_common.h" #endif // DGL_USE_CUDA using namespace dgl::runtime; namespace dgl { DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") .set_body([] (DGLArgs args, DGLRetValue *rv) { const int seed = args[0]; runtime::parallel_for(0, omp_get_max_threads(), [&](size_t b, size_t e) { for (auto i = b; i < e; ++i) { RandomEngine::ThreadLocal()->SetSeed(seed); } }); #ifdef DGL_USE_CUDA if (DeviceAPI::Get(kDLROCM)->IsAvailable()) { auto* thr_entry = CUDAThreadEntry::ThreadLocal(); if (!thr_entry->curand_gen) { CURAND_CALL(hiprandCreateGenerator(&thr_entry->curand_gen, HIPRAND_RNG_PSEUDO_DEFAULT)); } CURAND_CALL(hiprandSetPseudoRandomGeneratorSeed( thr_entry->curand_gen, static_cast(seed))); } #endif // DGL_USE_CUDA }); DGL_REGISTER_GLOBAL("rng._CAPI_Choice") .set_body([] (DGLArgs args, DGLRetValue *rv) { const int64_t num = args[0]; const int64_t population = args[1]; const NDArray prob = args[2]; const bool replace = args[3]; const int bits = args[4]; CHECK(bits == 32 || bits == 64) << "Supported bit widths are 32 and 64, but got " << bits << "."; if (aten::IsNullArray(prob)) { if (bits == 32) { *rv = RandomEngine::ThreadLocal()->UniformChoice(num, population, replace); } else { *rv = RandomEngine::ThreadLocal()->UniformChoice(num, population, replace); } } else { if (bits == 32) { ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", { *rv = RandomEngine::ThreadLocal()->Choice(num, prob, replace); }); } else { ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", { *rv = RandomEngine::ThreadLocal()->Choice(num, prob, replace); }); } } }); }; // namespace dgl