philox_rand.hpp 3.66 KB
Newer Older
guangzlu's avatar
guangzlu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

// Counter-based Philox-4x32 pseudo-random generator (7 rounds) for device code.
//
// State:
//   * counter - 128-bit block counter. Words (x, y) hold `offset / 4` as a
//               little-endian 64-bit value, words (z, w) hold `subsequence`.
//   * h_seed  - 64-bit key, split into low (x) and high (y) 32-bit halves.
//
// Each call to the round function turns one counter value into four 32-bit
// outputs. The overloads without a `subsequence` argument advance the stored
// counter by one after every draw; the overloads that take a `subsequence`
// only read the stored state and never modify it.
class philox
{
    public:
    // Builds a generator from a 64-bit seed, a subsequence id and an element
    // offset. The offset is divided by 4 because every counter value yields
    // four 32-bit outputs; the remainder (`offset % 4`) is discarded.
    __device__ inline philox(unsigned long long seed,
                             unsigned long long subsequence,
                             unsigned long long offset)
        : h_seed(split_u64(seed))
    {
        const uint2 block = split_u64(offset / 4);
        const uint2 subs  = split_u64(subsequence);

        counter.x = block.x;
        counter.y = block.y;
        counter.z = subs.x;
        counter.w = subs.y;
    }

    // Returns four 32-bit random words for the current counter value, then
    // advances the stored counter by one.
    __device__ inline uint4 get_philox_4x32()
    {
        const uint4 output = run_rounds(counter, h_seed);
        incr();

        return output;
    }

    // Returns four 32-bit random words for the current block counter with the
    // subsequence words (z, w) replaced by `subsequence`. The stored counter is
    // left untouched, so repeated calls with the same argument return the same
    // value.
    __device__ inline uint4 get_philox_4x32(const unsigned long long subsequence) const
    {
        const uint2 subs = split_u64(subsequence);

        uint4 counter_ = counter;
        counter_.z     = subs.x;
        counter_.w     = subs.y;

        return run_rounds(counter_, h_seed);
    }

    // Writes eight 16-bit random values to out[0..7] and advances the counter.
    // `out` must point to at least eight writable elements.
    __device__ void get_random_8x16(ushort* out) { store_8x16(out, get_philox_4x32()); }

    // Writes eight 16-bit random values for `subsequence` to out[0..7] without
    // advancing the counter. `out` must point to at least eight writable elements.
    __device__ void get_random_8x16(ushort* out, const unsigned long long subsequence) const
    {
        store_8x16(out, get_philox_4x32(subsequence));
    }

    // Writes four 16-bit random values for `subsequence` to out[0..3]: the low
    // 16 bits of each of the four 32-bit words. Does not advance the counter.
    __device__ void get_random_4x16(ushort* out, const unsigned long long subsequence) const
    {
        const uint4 tmp_ph = get_philox_4x32(subsequence);

        out[0] = static_cast<ushort>(tmp_ph.x);
        out[1] = static_cast<ushort>(tmp_ph.y);
        out[2] = static_cast<ushort>(tmp_ph.z);
        out[3] = static_cast<ushort>(tmp_ph.w);
    }

    private:
    uint4 counter;
    const uint2 h_seed;

    // Splits a 64-bit value into {low 32 bits, high 32 bits} using shifts
    // instead of pointer type punning.
    __device__ static inline uint2 split_u64(const unsigned long long value)
    {
        uint2 res;
        res.x = static_cast<uint32_t>(value & 0xFFFFFFFFull);
        res.y = static_cast<uint32_t>(value >> 32);
        return res;
    }

    // Stores the four 32-bit words as eight 16-bit values, low half first.
    // This is the same layout a 32-bit store produces on a little-endian
    // target, but it only needs the natural 2-byte alignment of `out`.
    __device__ static inline void store_8x16(ushort* out, const uint4 v)
    {
        out[0] = static_cast<ushort>(v.x & 0xFFFFu);
        out[1] = static_cast<ushort>(v.x >> 16);
        out[2] = static_cast<ushort>(v.y & 0xFFFFu);
        out[3] = static_cast<ushort>(v.y >> 16);
        out[4] = static_cast<ushort>(v.z & 0xFFFFu);
        out[5] = static_cast<ushort>(v.z >> 16);
        out[6] = static_cast<ushort>(v.w & 0xFFFFu);
        out[7] = static_cast<ushort>(v.w >> 16);
    }

    // Returns `ctr + 1`, treating the four words as one little-endian 128-bit
    // integer. A word that wraps to zero carries into the next one, so the
    // sequence does not repeat after 2^32 draws.
    __device__ static inline uint4 incr(uint4 ctr)
    {
        if(++ctr.x != 0)
            return ctr;
        if(++ctr.y != 0)
            return ctr;
        if(++ctr.z != 0)
            return ctr;
        ++ctr.w;
        return ctr;
    }

    __device__ inline void incr() { counter = incr(counter); }

    // Full 32x32 -> 64-bit multiply; returns {low 32 bits, high 32 bits}.
    __device__ static inline uint2 u32_high_low_multi(const unsigned int a, const unsigned int b)
    {
        return split_u64(static_cast<unsigned long long>(a) * b);
    }

    // One Philox round: two widening multiplies, then mix the high halves with
    // the remaining counter words and the round key.
    __device__ static inline uint4 single_loop(const uint4 ctr, const uint2 i_key)
    {
        const uint2 res0 = u32_high_low_multi(kPhiloxSA, ctr.x);
        const uint2 res1 = u32_high_low_multi(kPhiloxSB, ctr.z);

        uint4 ret;
        ret.x = res1.y ^ ctr.y ^ i_key.x;
        ret.y = res1.x;
        ret.z = res0.y ^ ctr.w ^ i_key.y;
        ret.w = res0.x;
        return ret;
    }

    // Applies kRounds Philox rounds to `ctr`, bumping the key between rounds.
    __device__ static inline uint4 run_rounds(uint4 ctr, uint2 key)
    {
// 7-round philox: kRounds - 1 rounds with a key bump, then one final round
#pragma unroll
        for(int i = 0; i < kRounds - 1; i++)
        {
            ctr = single_loop(ctr, key);
            key.x += kPhilox10A;
            key.y += kPhilox10B;
        }
        return single_loop(ctr, key);
    }

    static const int kRounds = 7;

    // Fixed-width 32-bit constants: key-schedule increments (A, B) and round
    // multipliers (SA, SB).
    static const uint32_t kPhilox10A = 0x9E3779B9;
    static const uint32_t kPhilox10B = 0xBB67AE85;
    static const uint32_t kPhiloxSA  = 0xD2511F53;
    static const uint32_t kPhiloxSB  = 0xCD9E8D57;
};

} // namespace ck