/*!
 *  Copyright (c) 2017 by Contributors
 * \file random.cc
 * \brief Random number generator interfaces
 */

#include <dmlc/omp.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/random.h>
#include <dgl/array.h>

#include <limits>

#ifdef DGL_USE_CUDA
#include "../runtime/cuda/cuda_common.h"
#endif  // DGL_USE_CUDA

using namespace dgl::runtime;

namespace dgl {

/*!
 * \brief FFI entry point "rng._CAPI_SetSeed": reseed DGL's random generators.
 *
 * Positional arguments:
 *   - args[0] (int): the seed.
 *
 * The OpenMP loop runs omp_get_max_threads() iterations, each calling
 * RandomEngine::ThreadLocal()->SetSeed(seed). The intent appears to be that
 * every OpenMP worker reseeds its own thread-local engine.
 * NOTE(review): that relies on the OpenMP schedule handing every worker at
 * least one iteration. This is typical for the default static schedule but is
 * not guaranteed by the standard; confirm if the schedule is ever changed.
 *
 * When built with DGL_USE_CUDA, each iterating thread also creates its cuRAND
 * generator on first use (cached in CUDAThreadEntry::ThreadLocal()) and seeds
 * it with seed + GetThreadId(). Presumably this gives each thread a distinct
 * stream; confirm GetThreadId()'s semantics in dgl/random.h.
 *
 * Nothing is written to \p rv.
 */
DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    const int seed = args[0];
#pragma omp parallel for
    for (int i = 0; i < omp_get_max_threads(); ++i) {
      RandomEngine::ThreadLocal()->SetSeed(seed);
#ifdef DGL_USE_CUDA
      auto* thr_entry = CUDAThreadEntry::ThreadLocal();
      if (!thr_entry->curand_gen) {
        CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
      }
      CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
          thr_entry->curand_gen,
          static_cast<uint64_t>(seed + GetThreadId())));
#endif  // DGL_USE_CUDA
    }
  });

/*!
 * \brief FFI entry point "rng._CAPI_Choice": draw `num` random indices.
 *
 * Positional arguments:
 *   - args[0] (int64_t) num: number of indices to draw.
 *   - args[1] (int64_t) population: population size, used by uniform sampling.
 *   - args[2] (NDArray) prob: sampling weights. A null array selects uniform
 *     sampling (RandomEngine::UniformChoice). Otherwise RandomEngine::Choice
 *     is instantiated for prob's floating-point dtype.
 *   - args[3] (bool) replace: whether to sample with replacement.
 *   - args[4] (int) bits: 32 or 64. Selects int32_t or int64_t as the index
 *     type of the sampler instantiation.
 *
 * The sampled index array is returned through \p rv. Fails a CHECK on an
 * unsupported bit width, or when bits == 32 and a count does not fit in int32_t.
 */
DGL_REGISTER_GLOBAL("rng._CAPI_Choice")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    const int64_t num = args[0];
    const int64_t population = args[1];
    const NDArray prob = args[2];
    const bool replace = args[3];
    const int bits = args[4];
    CHECK(bits == 32 || bits == 64)
      << "Supported bit widths are 32 and 64, but got " << bits << ".";
    const bool uniform = aten::IsNullArray(prob);

    if (bits == 32) {
      // UniformChoice<int32_t> / Choice<int32_t, ...> presumably take their
      // count arguments as int32_t (see dgl/random.h). Passing the int64_t
      // values through unchecked would silently wrap on narrowing, so reject
      // them up front. `population` is only consumed on the uniform path.
      constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
      CHECK(num <= kInt32Max)
        << "num = " << num << " does not fit in a 32-bit index.";
      CHECK(!uniform || population <= kInt32Max)
        << "population = " << population << " does not fit in a 32-bit index.";
    }

    if (uniform) {
      if (bits == 32) {
        *rv = RandomEngine::ThreadLocal()->UniformChoice<int32_t>(num, population, replace);
      } else {
        *rv = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(num, population, replace);
      }
    } else {
      if (bits == 32) {
        ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
          *rv = RandomEngine::ThreadLocal()->Choice<int32_t, FloatType>(num, prob, replace);
        });
      } else {
        ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
          *rv = RandomEngine::ThreadLocal()->Choice<int64_t, FloatType>(num, prob, replace);
        });
      }
    }
  });

}  // namespace dgl