// !!! This is a file automatically generated by hipify!!!
/*!
 *   Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 *   All rights reserved.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 *
 * @file dgl/continuous_seed.h
 * @brief CPU and GPU (HIP) implementations of continuous random seeds
 */
#ifndef DGL_RANDOM_CONTINUOUS_SEED_H_
#define DGL_RANDOM_CONTINUOUS_SEED_H_

#include <dgl/array.h>

#include <cmath>

#ifdef __HIPCC__
#include <hiprand/hiprand_kernel.h>
#else
#include <random>

#include "pcg_random.hpp"
#endif  // __HIPCC__

#if !defined(M_SQRT1_2)
// 1/sqrt(2). M_SQRT1_2 is a POSIX extension rather than standard C++, so
// <cmath> is not guaranteed to provide it; supply it when it is missing.
#define M_SQRT1_2 0.707106781186547524401
#endif  // M_SQRT1_2

namespace dgl {
namespace random {

class continuous_seed {
  uint64_t s[2];
  float c[2];

 public:
  /* implicit */ continuous_seed(const int64_t seed) {  // NOLINT
    s[0] = s[1] = seed;
    c[0] = c[1] = 0;
  }

  continuous_seed(IdArray seed_arr, float r) {
    auto seed = seed_arr.Ptr<int64_t>();
    s[0] = seed[0];
    s[1] = seed[seed_arr->shape[0] - 1];
    const auto pi = std::acos(-1.0);
    c[0] = std::cos(pi * r / 2);
    c[1] = std::sin(pi * r / 2);
  }

sangwzh's avatar
sangwzh committed
62
#ifdef __HIP_DEVICE_COMPILE__
63
64
  __device__ inline float uniform(const uint64_t t) const {
    const uint64_t kCurandSeed = 999961;  // Could be any random number.
sangwzh's avatar
sangwzh committed
65
66
    hiprandStatePhilox4_32_10_t rng;
    hiprand_init(kCurandSeed, s[0], t, &rng);
67
68
    float rnd;
    if (s[0] != s[1]) {
sangwzh's avatar
sangwzh committed
69
70
71
      rnd = c[0] * hiprand_normal(&rng);
      hiprand_init(kCurandSeed, s[1], t, &rng);
      rnd += c[1] * hiprand_normal(&rng);
72
73
      rnd = normcdff(rnd);
    } else {
sangwzh's avatar
sangwzh committed
74
      rnd = hiprand_uniform(&rng);
75
76
77
78
    }
    return rnd;
  }
#else
sangwzh's avatar
sangwzh committed
79
  __host__ inline float uniform(const uint64_t t) const {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
    pcg32 ng0(s[0], t);
    float rnd;
    if (s[0] != s[1]) {
      std::normal_distribution<float> norm;
      rnd = c[0] * norm(ng0);
      pcg32 ng1(s[1], t);
      norm.reset();
      rnd += c[1] * norm(ng1);
      rnd = std::erfc(-rnd * static_cast<float>(M_SQRT1_2)) / 2.0f;
    } else {
      std::uniform_real_distribution<float> uni;
      rnd = uni(ng0);
    }
    return rnd;
  }
sangwzh's avatar
sangwzh committed
95
#endif  // __HIP_DEVICE_COMPILE__
96
97
98
99
100
101
};

}  // namespace random
}  // namespace dgl

#endif  // DGL_RANDOM_CONTINUOUS_SEED_H_