magic_division.hpp 7.24 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
#pragma once
5

Chao Liu's avatar
Chao Liu committed
6
#include "ck/ck.hpp"
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"

namespace ck {

// magic number division
// Caution:
//   0. The divisor must be within [1, INT32_MAX]; magic numbers are not computed for other values.
//   1. For uint32_t as dividend: the magic number division implementation used here produces
//   correct results only if the dividend's value is within the 31-bit range [0, 2^31).
//   2. For int32_t as dividend: magic number division for an int32_t dividend has not been
//   implemented; the int32_t dividend is bit-wise reinterpreted as uint32_t and the magic number
//   division implementation for uint32_t is used instead. Therefore, the dividend value needs to
//   be non-negative.
// TODO:
//   1. Implement magic number division for int32_t
//   2. Implement magic number division for uint32_t with the full 32-bit value range
struct MagicDivision
{
    // uint32_t
    __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
    {
30
31
32
33
        // WARNING: magic division is only applicable for division inside this range.
        // You should use the return value of CalculateMagicNumbers, if division is not inside this
        // range. The "else" logic below is to quiet down run-time error.
        if(divisor >= 1 && divisor <= INT32_MAX)
34
        {
35
36
            uint32_t shift = 0;
            for(shift = 0; shift < 32; ++shift)
37
            {
38
39
40
41
                if((1U << shift) >= divisor)
                {
                    break;
                }
42
43
            }

44
45
46
            uint64_t one        = 1;
            uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
            // assert(multiplier <= 0xffffffffUL);
47

48
49
50
51
52
53
            return make_tuple(uint32_t(multiplier), shift);
        }
        else
        {
            return make_tuple(uint32_t(0), uint32_t(0));
        }
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    }

    __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
    {
        auto tmp = CalculateMagicNumbers(divisor);

        return tmp[Number<0>{}];
    }

    __host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor)
    {
        auto tmp = CalculateMagicNumbers(divisor);

        return tmp[Number<1>{}];
    }

    // integral_constant<uint32_t, .>
    template <uint32_t Divisor>
    __host__ __device__ static constexpr auto
        CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
    {
        constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});

        constexpr uint32_t multiplier = tmp[Number<0>{}];
        constexpr uint32_t shift      = tmp[Number<1>{}];

        return make_tuple(integral_constant<uint32_t, multiplier>{},
                          integral_constant<uint32_t, shift>{});
    }

    template <uint32_t Divisor>
    __host__ __device__ static constexpr auto
        CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
    {
        constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});

        return integral_constant<uint32_t, multiplier>{};
    }

    template <uint32_t Divisor>
    __host__ __device__ static constexpr auto
        CalculateMagicShift(integral_constant<uint32_t, Divisor>)
    {
        constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});

        return integral_constant<uint32_t, shift>{};
    }

    // integral_constant<int32_t, .>
    template <int32_t Divisor>
    __host__ __device__ static constexpr auto
        CalculateMagicNumbers(integral_constant<int32_t, Divisor>)
    {
        return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
    }

    template <int32_t Divisor>
    __host__ __device__ static constexpr auto
        CalculateMagicMultiplier(integral_constant<int32_t, Divisor>)
    {
        return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
    }

    template <int32_t Divisor>
    __host__ __device__ static constexpr auto
        CalculateMagicShift(integral_constant<int32_t, Divisor>)
    {
        return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
    }

    // magic division for uint32_t
Jianfeng Yan's avatar
Jianfeng Yan committed
125
    __device__ static constexpr uint32_t
126
127
    DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
Chao Liu's avatar
Chao Liu committed
128
        uint32_t tmp = __umulhi(dividend, multiplier);
129
130
131
        return (tmp + dividend) >> shift;
    }

Jianfeng Yan's avatar
Jianfeng Yan committed
132
133
134
135
136
137
138
    __host__ static constexpr uint32_t
    DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
        return (tmp + dividend) >> shift;
    }

Chao Liu's avatar
Chao Liu committed
139
    // magic division for int32_t
140
141
142
    // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
    // non-negative for result to be correct
    // TODO: figure out how to do magic number divison for int32_t as dividended
Jianfeng Yan's avatar
Jianfeng Yan committed
143
    __device__ static constexpr int32_t
144
145
    DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
146
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
Chao Liu's avatar
Chao Liu committed
147
        uint32_t tmp          = __umulhi(dividend_u32, multiplier);
148
        return (tmp + dividend_u32) >> shift;
149
    }
Jianfeng Yan's avatar
Jianfeng Yan committed
150
151
152
153
154
155
156
157

    __host__ static constexpr int32_t
    DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = static_cast<uint64_t>(dividend_u32) * multiplier >> 32;
        return (tmp + dividend_u32) >> shift;
    }
158
159
};

160
161
162
163
164
// A runtime divisor stored together with its precomputed magic numbers. Division and modulo by
// that divisor go through MagicDivision::DoMagicDivision (multiply-high, add, shift).
// Limits inherited from MagicDivision: the divisor must be in [1, INT32_MAX] and dividends must
// be in [0, 2^31) for results to be correct.
struct MDiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    // Stores divisor_ and computes its magic numbers.
    __host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
    {
        ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
    }

    // All fields zero. With multiplier = 0 and shift = 0, div() returns its argument unchanged;
    // call update() to set a real divisor before dividing.
    __host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {}

    // Replace the divisor and recompute its magic numbers.
    __host__ __device__ void update(uint32_t divisor_)
    {
        divisor                    = divisor_;
        ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
    }

    // Returns dividend_ / divisor.
    __host__ __device__ uint32_t div(uint32_t dividend_) const
    {
        return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
    }

    // quotient_ = dividend_ / divisor, remainder_ = dividend_ % divisor.
    // The remainder is recovered from the quotient, so no modulo instruction is needed.
    __host__ __device__ void
    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor);
    }

    // Returns the stored divisor.
    __host__ __device__ uint32_t get() const { return divisor; }
};

196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
struct MDiv2
{
    // 1 dword -> 2 dword storage, divisor need compute from runtime
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    __host__ __device__ MDiv2(uint32_t divisor_)
    {
        ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
    }

    __host__ __device__ MDiv2() : multiplier(0), shift(0) {}

    __host__ __device__ uint32_t div(uint32_t dividend_) const
    {
        return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
    }

    __host__ __device__ void
    divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor_);
    }
};

223
} // namespace ck