#ifndef CK_MATH_HPP
#define CK_MATH_HPP

#include "config.hpp"
#include "integral_constant.hpp"

namespace ck {
namespace math {

// Unary function object: multiplies its argument by the compile-time
// constant factor `s` (returns s * value).
template <typename T, T s>
struct scales
{
    __host__ __device__ constexpr T operator()(T value) const
    {
        return s * value;
    }
};

// Binary function object returning lhs + rhs.
template <typename T>
struct plus
{
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const
    {
        return lhs + rhs;
    }
};

// Binary function object returning lhs - rhs (first operand minus second).
template <typename T>
struct minus
{
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const
    {
        return lhs - rhs;
    }
};

// Binary function object returning lhs * rhs.
template <typename T>
struct multiplies
{
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const
    {
        return lhs * rhs;
    }
};

// Function object computing the rounded-up quotient ceil(a / b) as
// (a + b - 1) / b. Only T = index_t or T = int is accepted (checked at
// compile time). The formula is correct for a >= 0 and b > 0 as long as
// a + b - 1 does not overflow.
template <typename T>
struct integer_divide_ceiler
{
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        static_assert(is_same<T, int>{} || is_same<T, index_t>{}, "wrong type");

        // Bias the numerator so that truncating division rounds up.
        const auto biased_numerator = a + b - 1;
        return biased_numerator / b;
    }
};

// Integer division of x by y rounded up, i.e. ceil(x / y), computed as
// (x + y - 1) / y. The result type is whatever that expression yields for
// X and Y.
//
// Preconditions: x >= 0 and y > 0; x + y - 1 must not overflow. For
// negative x the truncating division does not always round up.
template <class X, class Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
    return (x + y - 1) / y;
}

// Smallest multiple of y that is not less than x, i.e. y * ceil(x / y),
// e.g. integer_least_multiple(10, 4) == 12.
//
// Inherits the preconditions of integer_divide_ceil (x >= 0, y > 0, no
// overflow in x + y - 1); the final multiplication must also not overflow.
template <class X, class Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
{
    return y * integer_divide_ceil(x, y);
}

// Base case of the variadic max(): a single value is its own maximum.
template <typename T>
__host__ __device__ constexpr T max(T only)
{
    return only;
}

// Largest of x and the remaining arguments xs...
// All arguments must have exactly the same type T (enforced by
// static_assert). When values compare equal, the one coming later in the
// argument list is returned.
template <typename T, typename... Ts>
__host__ __device__ constexpr T max(T x, Ts... xs)
{
    static_assert(sizeof...(xs) > 0, "not enough argument");

    // Reduce the tail first, then compare it against the head.
    auto tail_max = max(xs...);

    static_assert(is_same<decltype(tail_max), T>{}, "not the same type");

    if(x > tail_max)
    {
        return x;
    }
    return tail_max;
}

// Base case of the variadic min(): a single value is its own minimum.
template <typename T>
__host__ __device__ constexpr T min(T only)
{
    return only;
}

// Smallest of x and the remaining arguments xs...
// All arguments must have exactly the same type T (enforced by
// static_assert). When values compare equal, the one coming later in the
// argument list is returned.
template <typename T, typename... Ts>
__host__ __device__ constexpr T min(T x, Ts... xs)
{
    static_assert(sizeof...(xs) > 0, "not enough argument");

    // Reduce the tail first, then compare it against the head.
    auto tail_min = min(xs...);

    static_assert(is_same<decltype(tail_min), T>{}, "not the same type");

    if(x < tail_min)
    {
        return x;
    }
    return tail_min;
}

// Greatest common divisor of x and y (Euclid's algorithm on absolute
// values). gcd(0, 0) == 0 and gcd(0, y) == |y|, matching std::gcd.
template <class T>
__host__ __device__ constexpr T gcd(T x, T y)
{
    // For unsigned T the sign tests are simply always false.
    T a = x < T(0) ? T(-x) : x;
    T b = y < T(0) ? T(-y) : y;

    while(b != T(0))
    {
        const T r = a % b;
        a         = b;
        b         = r;
    }
    return a;
}

// Base case of the variadic lcm(): a single argument is returned unchanged.
template <class T>
__host__ __device__ constexpr T lcm(T x)
{
    return x;
}

// Least common multiple of all arguments (folded pairwise, left to right).
//
// All arguments must have the same type T; mixing types fails template
// deduction, mirroring the same-type requirement of max()/min(). With two
// or more arguments the result is non-negative, and it is 0 if any argument
// is 0 (matching std::lcm). For powers of two the result coincides with
// max() of the arguments. The caller must ensure the result fits in T.
template <class T, class... Ts>
__host__ __device__ constexpr T lcm(T x, T y, Ts... zs)
{
    const T ax = x < T(0) ? T(-x) : x;
    const T ay = y < T(0) ? T(-y) : y;
    const T g  = gcd(ax, ay);

    // g == 0 only when both inputs are 0. Divide before multiplying so the
    // intermediate value never exceeds the final result.
    const T pair_lcm = g == T(0) ? T(0) : ax / g * ay;

    return lcm(pair_lcm, zs...);
}

} // namespace math
} // namespace ck

#endif