// PyTorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
#pragma once
// Philox counter-based RNG for CUDA device code.

namespace flash {

// Pair of 64-bit unsigned values (16 bytes total). Code in this file overlays
// it on a uint4 to address a 128-bit counter as two 64-bit halves.
struct ull2 {
    unsigned long long x;  // first 64-bit half
    unsigned long long y;  // second 64-bit half
};

// Full 32 x 32 -> 64-bit unsigned multiply of `a` and `b`.
// Returns the product split into 32-bit words: .x holds the low word and
// .y holds the high word.
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
    unsigned long long product;
    // mul.wide.u32 yields the complete 64-bit product in one instruction.
    asm ("mul.wide.u32 %0, %1, %2;\n\t"
          : "=l"(product)
          : "r"(a), "r"(b));
    uint2 words;
    words.x = static_cast<unsigned int>(product);        // low 32 bits
    words.y = static_cast<unsigned int>(product >> 32);  // high 32 bits
    return words;
}

// Applies one Philox round to the 128-bit counter `ctr` using the 64-bit
// round key `key`, and returns the transformed counter.
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
    // Round multipliers applied to counter words x and z respectively.
    constexpr unsigned int kMultiplierA = 0xD2511F53u;
    constexpr unsigned int kMultiplierB = 0xCD9E8D57u;

    // Each product comes back as (.x = low word, .y = high word).
    const uint2 prodA = mulhilo32(kMultiplierA, ctr.x);
    const uint2 prodB = mulhilo32(kMultiplierB, ctr.z);

    uint4 mixed;
    mixed.x = prodB.y ^ ctr.y ^ key.x;
    mixed.y = prodB.x;
    mixed.z = prodA.y ^ ctr.w ^ key.y;
    mixed.w = prodA.x;
    return mixed;
}

// Stateless Philox: produces 128 random bits for (seed, subsequence, offset).
//   seed        - 64-bit key; its low/high words become key.x / key.y.
//   subsequence - fills the upper 64 bits of the counter (words z, w).
//   offset      - fills the lower 64 bits of the counter (words x, y).
// Runs seven rounds in total; the key is advanced between rounds but not
// after the last one.
inline __device__ uint4 philox(unsigned long long seed,
                               unsigned long long subsequence,
                               unsigned long long offset) {
    // Per-round key increments; additions wrap modulo 2^32.
    constexpr unsigned int kKeyStepX = 0x9E3779B9u;
    constexpr unsigned int kKeyStepY = 0xBB67AE85u;
    constexpr int kRounds = 7;

    uint2 key;
    key.x = static_cast<unsigned int>(seed);
    key.y = static_cast<unsigned int>(seed >> 32);

    uint4 counter;
    counter.x = static_cast<unsigned int>(offset);
    counter.y = static_cast<unsigned int>(offset >> 32);
    counter.z = static_cast<unsigned int>(subsequence);
    counter.w = static_cast<unsigned int>(subsequence >> 32);

    // All rounds but the last are followed by a key bump.
    #pragma unroll
    for (int round = 0; round < kRounds - 1; ++round) {
        counter = philox_single_round(counter, key);
        key.x += kKeyStepX;
        key.y += kKeyStepY;
    }
    return philox_single_round(counter, key);
}

} // namespace flash

// Generator-object wrapper around flash::philox; the anonymous namespace keeps it local to each translation unit.
namespace {

// Philox random-number generator object for device code.
//
// The constructor records the seed and offset and also pre-loads a 128-bit
// counter and a 64-bit key. operator() currently produces its output from the
// stored seed and offset only (see the review note on operator()).
class Philox {
public:
  // seed        - 64-bit seed; stored in seed_ and split into key.x (low word)
  //               and key.y (high word).
  // subsequence - loaded into the upper 64 bits of `counter` (words z, w).
  // offset      - stored unchanged in offset_; offset / 4 is loaded into the
  //               lower 64 bits of `counter` (words x, y).
  __device__ inline Philox(unsigned long long seed,
                           unsigned long long subsequence,
                           unsigned long long offset)
      // Initializers are listed in member declaration order, which is the
      // order C++ actually runs them in.
      : offset_(offset)
      , seed_(seed)
      , key(split64(seed))
      , STATE(0) {
    const uint2 lower = split64(offset / 4);
    const uint2 upper = split64(subsequence);
    counter.x = lower.x;
    counter.y = lower.y;
    counter.z = upper.x;
    counter.w = upper.y;
  }

  // Returns 128 random bits computed by flash::philox from seed_ and offset_.
  //
  // NOTE(review): offset_ is passed as BOTH the subsequence and the offset, so
  // the `subsequence` constructor argument and the `counter` / `key` / `STATE`
  // members do not influence the result, and nothing is advanced between
  // calls: repeated calls on the same object return the same value. The
  // private incr*() helpers suggest a stateful counter-mode variant was
  // intended. Confirm against the callers before changing this, because any
  // change alters the generated random stream.
  __device__ inline uint4 operator()() {
    return flash::philox(seed_, offset_, offset_);
  }

private:
  // Splits a 64-bit value into 32-bit words: .x = low word, .y = high word.
  __device__ static inline uint2 split64(unsigned long long value) {
    uint2 words;
    words.x = static_cast<unsigned int>(value);
    words.y = static_cast<unsigned int>(value >> 32);
    return words;
  }

  unsigned long long offset_, seed_;
  uint4 counter;       // 128-bit counter; x is the least-significant word
  const uint2 key;     // seed split into two 32-bit words
  unsigned int STATE;  // set to 0 by the constructor; not read anywhere in this class

  // Adds the 64-bit value `n` to the 128-bit counter.
  __device__ inline void incr_n(unsigned long long n) {
    const unsigned int addLo = static_cast<unsigned int>(n);
    unsigned int addHi = static_cast<unsigned int>(n >> 32);

    counter.x += addLo;
    if (counter.x < addLo) {
      // counter.x wrapped: fold the carry into the high addend.
      ++addHi;
    }
    counter.y += addHi;
    if (counter.y < addHi) {
      // counter.y wrapped: ripple the carry into z, and into w if z wraps too.
      if (++counter.z == 0) {
        ++counter.w;
      }
    }
  }

  // Returns ctr + 1, treating ctr as a 128-bit integer with x as the
  // least-significant word. The carry is propagated by the add.cc / addc
  // instruction chain.
  __device__ uint4 incr128 (uint4 ctr)
  {
    uint4 res;
    asm ("add.cc.u32      %0, %4, %8;\n\t"
         "addc.cc.u32     %1, %5, %9;\n\t"
         "addc.cc.u32     %2, %6, %10;\n\t"
         "addc.u32        %3, %7, %11;\n\t"
         : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
         : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
           "n"(1), "n"(0), "n"(0), "n"(0));
    return res;
  }

  // Advances the stored counter by one.
  __device__ inline void incr() {
    counter = incr128(counter);
  }

  // Per-round key increments (same values flash::philox uses); not referenced
  // by the current operator().
  static const unsigned long kPhilox10A = 0x9E3779B9;
  static const unsigned long kPhilox10B = 0xBB67AE85;
};

} // namespace