math.hpp 4.62 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

Chao Liu's avatar
Chao Liu committed
4
#pragma once
Chao Liu's avatar
Chao Liu committed
5

Chao Liu's avatar
Chao Liu committed
6
#include "ck/ck.hpp"
Chao Liu's avatar
Chao Liu committed
7
#include "integral_constant.hpp"
Chao Liu's avatar
Chao Liu committed
8
#include "number.hpp"
Chao Liu's avatar
Chao Liu committed
9
#include "type.hpp"
Chao Liu's avatar
Chao Liu committed
10
#include "enable_if.hpp"
Chao Liu's avatar
Chao Liu committed
11

carlushuang's avatar
carlushuang committed
12
13
14
15
#ifndef CK_NOCPU
#include <math.h>
#endif

Chao Liu's avatar
Chao Liu committed
16
17
18
namespace ck {
namespace math {

zjing14's avatar
zjing14 committed
19
template <typename T, T s>
Chao Liu's avatar
Chao Liu committed
20
21
22
23
24
struct scales
{
    __host__ __device__ constexpr T operator()(T a) const { return s * a; }
};

zjing14's avatar
zjing14 committed
25
template <typename T>
Chao Liu's avatar
Chao Liu committed
26
27
28
29
30
struct plus
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
};

zjing14's avatar
zjing14 committed
31
template <typename T>
Chao Liu's avatar
Chao Liu committed
32
33
34
35
36
37
struct minus
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
};

struct multiplies
Chao Liu's avatar
Chao Liu committed
38
39
40
41
42
43
44
45
{
    template <typename A, typename B>
    __host__ __device__ constexpr auto operator()(const A& a, const B& b) const
    {
        return a * b;
    }
};

zjing14's avatar
zjing14 committed
46
template <typename T>
Chao Liu's avatar
Chao Liu committed
47
struct maximize
Chao Liu's avatar
Chao Liu committed
48
49
50
51
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};

zjing14's avatar
zjing14 committed
52
template <typename T>
Chao Liu's avatar
Chao Liu committed
53
54
55
56
57
struct minimize
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};

zjing14's avatar
zjing14 committed
58
template <typename T>
Chao Liu's avatar
Chao Liu committed
59
60
61
62
63
64
struct integer_divide_ceiler
{
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");

zjing14's avatar
zjing14 committed
65
        return (a + b - Number<1>{}) / b;
Chao Liu's avatar
Chao Liu committed
66
67
68
    }
};

zjing14's avatar
zjing14 committed
69
template <typename X, typename Y>
70
71
72
73
74
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
    return x / y;
}

zjing14's avatar
zjing14 committed
75
template <typename X, typename Y>
Chao Liu's avatar
Chao Liu committed
76
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Chao Liu's avatar
Chao Liu committed
77
{
78
    return (x + y - Number<1>{}) / y;
Chao Liu's avatar
Chao Liu committed
79
80
}

zjing14's avatar
zjing14 committed
81
template <typename X, typename Y>
Chao Liu's avatar
Chao Liu committed
82
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Chao Liu's avatar
Chao Liu committed
83
{
Chao Liu's avatar
Chao Liu committed
84
    return y * integer_divide_ceil(x, y);
Chao Liu's avatar
Chao Liu committed
85
86
}

zjing14's avatar
zjing14 committed
87
template <typename T>
Chao Liu's avatar
Chao Liu committed
88
89
90
91
92
__host__ __device__ constexpr T max(T x)
{
    return x;
}

zjing14's avatar
zjing14 committed
93
94
// Larger of two values of the same type. When `x > y` is false (ties included), `y` is returned.
template <typename T>
__host__ __device__ constexpr T max(T x, T y)
{
    if(x > y)
    {
        return x;
    }
    return y;
}
Chao Liu's avatar
Chao Liu committed
98

zjing14's avatar
zjing14 committed
99
100
101
102
103
// Larger of a compile-time Number<X> and a runtime index_t; the result is a runtime index_t.
template <index_t X>
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
{
    if(X > y)
    {
        return X;
    }
    return y;
}
Chao Liu's avatar
Chao Liu committed
104

zjing14's avatar
zjing14 committed
105
106
107
108
109
// Larger of a runtime index_t and a compile-time Number<Y>; the result is a runtime index_t.
template <index_t Y>
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
{
    if(x > Y)
    {
        return x;
    }
    return Y;
}
Chao Liu's avatar
Chao Liu committed
110

zjing14's avatar
zjing14 committed
111
112
113
114
115
116
// Maximum of two or more values. It first reduces the trailing arguments, then compares the head
// against that partial result; each comparison dispatches to the max overloads above.
template <typename X, typename... Ys>
__host__ __device__ constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");

    const auto tail_max = max(ys...);
    return max(x, tail_max);
}

zjing14's avatar
zjing14 committed
119
template <typename T>
Chao Liu's avatar
Chao Liu committed
120
121
122
123
124
__host__ __device__ constexpr T min(T x)
{
    return x;
}

zjing14's avatar
zjing14 committed
125
126
127
128
129
130
131
132
// Smaller of two values of the same type. When `x < y` is false (ties included), `y` is returned.
template <typename T>
__host__ __device__ constexpr T min(T x, T y)
{
    if(x < y)
    {
        return x;
    }
    return y;
}

// Smaller of a compile-time Number<X> and a runtime index_t; the result is a runtime index_t.
template <index_t X>
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
{
    if(X < y)
    {
        return X;
    }
    return y;
}
Chao Liu's avatar
Chao Liu committed
136

zjing14's avatar
zjing14 committed
137
138
139
140
141
// Smaller of a runtime index_t and a compile-time Number<Y>; the result is a runtime index_t.
template <index_t Y>
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
{
    if(x < Y)
    {
        return x;
    }
    return Y;
}
Chao Liu's avatar
Chao Liu committed
142

zjing14's avatar
zjing14 committed
143
144
145
146
// Minimum of two or more values. It first reduces the trailing arguments, then compares the head
// against that partial result; each comparison dispatches to the min overloads above.
template <typename X, typename... Ys>
__host__ __device__ constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");

    const auto tail_min = min(ys...);
    return min(x, tail_min);
}

carlushuang's avatar
carlushuang committed
151
#ifndef CK_NOGPU
// Device-only exponential.
// The primary template is declared but never defined, and T is deduced exactly from the
// argument. Only the explicitly specialized types below are usable: no implicit conversion
// (e.g. half -> float) happens, and an unsupported type fails to link.
template <typename T>
__device__ T exp(T x);

// TODO: add f16 support using v_exp_f16

// float: uses the __expf fast intrinsic, which trades accuracy for speed.
template <>
__device__ float exp<float>(float x)
{
    return __expf(x);
}

// double: forwards to the global-namespace math-library exp(double).
// The call must be qualified with `::`. Inside namespace ck::math, unqualified lookup of `exp`
// finds this function template first and stops, hiding the global overload. An unqualified
// `exp(x)` would therefore call exp<double> itself and recurse forever.
template <>
__device__ double exp<double>(double x)
{
    return ::exp(x);
}
#endif
170

171
// greatest common divisor, aka highest common factor
Chao Liu's avatar
Chao Liu committed
172
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Chao Liu's avatar
Chao Liu committed
173
{
174
175
176
177
178
179
180
181
182
    if(x < 0)
    {
        return gcd(-x, y);
    }
    else if(y < 0)
    {
        return gcd(x, -y);
    }
    else if(x == y || x == 0)
Chao Liu's avatar
Chao Liu committed
183
184
185
    {
        return y;
    }
Chao Liu's avatar
Chao Liu committed
186
    else if(y == 0)
Chao Liu's avatar
Chao Liu committed
187
188
189
    {
        return x;
    }
Chao Liu's avatar
Chao Liu committed
190
    else if(x > y)
Chao Liu's avatar
Chao Liu committed
191
    {
192
        return gcd(x % y, y);
Chao Liu's avatar
Chao Liu committed
193
    }
Chao Liu's avatar
Chao Liu committed
194
    else
Chao Liu's avatar
Chao Liu committed
195
    {
196
        return gcd(x, y % x);
Chao Liu's avatar
Chao Liu committed
197
198
199
200
    }
}

template <index_t X, index_t Y>
201
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
Chao Liu's avatar
Chao Liu committed
202
{
Chao Liu's avatar
Chao Liu committed
203
204
205
    constexpr auto r = gcd(X, Y);

    return Number<r>{};
Chao Liu's avatar
Chao Liu committed
206
207
}

Chao Liu's avatar
Chao Liu committed
208
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
209
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
Chao Liu's avatar
Chao Liu committed
210
{
211
    return gcd(x, gcd(ys...));
Chao Liu's avatar
Chao Liu committed
212
213
214
}

// least common multiple
Chao Liu's avatar
Chao Liu committed
215
216
template <typename X, typename Y>
__host__ __device__ constexpr auto lcm(X x, Y y)
Chao Liu's avatar
Chao Liu committed
217
{
218
    return (x * y) / gcd(x, y);
Chao Liu's avatar
Chao Liu committed
219
220
}

Chao Liu's avatar
Chao Liu committed
221
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
Chao Liu's avatar
Chao Liu committed
222
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
Chao Liu's avatar
Chao Liu committed
223
{
Chao Liu's avatar
Chao Liu committed
224
    return lcm(x, lcm(ys...));
Chao Liu's avatar
Chao Liu committed
225
226
}

zjing14's avatar
zjing14 committed
227
template <typename T>
Chao Liu's avatar
Chao Liu committed
228
229
230
231
232
struct equal
{
    __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
};

zjing14's avatar
zjing14 committed
233
template <typename T>
Chao Liu's avatar
Chao Liu committed
234
235
236
237
238
struct less
{
    __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};

Chao Liu's avatar
Chao Liu committed
239
} // namespace math
Chao Liu's avatar
Chao Liu committed
240
} // namespace ck