// util.h
// (Repository blame: the following lines were last changed by Guolin Ke.)
#pragma once

// Integer ceiling division: for non-negative a and positive b, yields the
// smallest q with q * (b) >= (a). `a` is expanded once and `b` twice, so do not
// pass arguments with side effects. (Name presumably intended as "DIV_CEIL";
// kept as-is for existing callers.)
#define DIV_CELL(a, b) (((a) + (b) - 1) / (b))

// Expands to `constexpr` under C++17 or newer, so `if IF_CONSTEXPR (cond)`
// becomes a compile-time branch. On older standards it expands to nothing,
// leaving an ordinary runtime `if`.
#if defined(__cplusplus) && (__cplusplus >= 201703L)
#  define IF_CONSTEXPR constexpr
#else
#  define IF_CONSTEXPR
#endif

// (Repository blame: the following section was last changed by sangwz.)
// swz
// HIP-only definitions, compiled only when __HIP_PLATFORM_HCC__ is defined.
// This section defines hip_bfloat162, a pair of hip_bfloat16 values. The
// VecType table below uses it as the 2-wide bfloat16 type on HIP.
#ifdef __HIP_PLATFORM_HCC__
#include<hip/hip_bfloat16.h>
#if defined(__HIPCC_RTC__)
#define __HOST_DEVICE__ __device__
#else
#define __HOST_DEVICE__ __host__ __device__
// TODO: Clang has a bug which allows device functions to call std functions
// when std functions are introduced into default namespace by using statement.
// math.h may be included after this bug is fixed.
#if __cplusplus
#include <cmath>
#else
#include "math.h"
#endif
#endif // !defined(__HIPCC_RTC__)

// Two-element bfloat16 vector for HIP.
// alignas(4) matches the 4-byte alignment of CUDA's __nv_bfloat162. With it,
// the 2-wide bfloat16 vector type has the same size and alignment on both
// platforms and can be accessed as a single 32-bit word.
struct alignas(4) hip_bfloat162
{
    hip_bfloat16 x;
    hip_bfloat16 y;
public:
    __HOST_DEVICE__
    hip_bfloat162() = default;
    __HOST_DEVICE__
    hip_bfloat162(const hip_bfloat16& in1, const hip_bfloat16& in2):x{in1},y{in2}
    {}
    // Plain member-wise copy. The earlier hand-written version converted each
    // element to float and back, which added needless work. Because it was
    // user-provided, it also made the type non-trivially-copyable.
    __HOST_DEVICE__
    hip_bfloat162& operator =(const hip_bfloat162&) = default;
};
static_assert(sizeof(hip_bfloat162) == 4,
              "hip_bfloat162 must be exactly two packed 16-bit values");
#endif

// (Repository blame: the following lines were last changed by Guolin Ke.)
// Butterfly warp shuffle: returns `value` as held by the lane whose index is
// (caller's lane index XOR laneMask), within sub-groups of `width` lanes.
//
// On CUDA 9+ this uses __shfl_xor_sync. `mask` names the participating lanes
// (all 32 by default), and every lane named in `mask` must execute this call.
// On HIP or pre-CUDA-9 toolchains it uses the legacy __shfl_xor, and `mask` is
// ignored.
//
// The toolchain is detected with nvcc's always-defined __CUDACC_VER_MAJOR__ as
// well as CUDA_VERSION. CUDA_VERSION comes from <cuda.h>; relying on it alone
// silently selected the legacy mask-less intrinsic whenever <cuda.h> had not
// been included. That intrinsic is unavailable on Volta (SM70) and newer.
template <typename T>
__device__ __forceinline__ T SHFL_XOR(T value, int laneMask, int width, unsigned int mask = 0xffffffff)
{
#if !defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_AMD__) && \
    ((defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9) || \
     (defined(CUDA_VERSION) && CUDA_VERSION >= 9000))
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    (void)mask;  // no participant mask on the legacy / HIP intrinsic
    return __shfl_xor(value, laneMask, width);
#endif
}

// Maps an element type T and a lane count N to the type that holds N packed
// elements. Only the primary template is declared here, never defined. Each
// supported (T, N) pair needs an explicit specialization; any other pair is an
// incomplete type and fails to compile when its `type` is requested.
template <typename T, int N>
struct VecTypeImpl;

// DEFINE_VEC_TYPE(t, n, tn) declares the explicit specialization
// VecTypeImpl<t, n> with nested `type` equal to `tn`. Use it at namespace scope
// with no trailing semicolon; the expansion already supplies one.
#define DEFINE_VEC_TYPE(t, n, tn)  \
    template <>                    \
    struct VecTypeImpl<t, n> {     \
        typedef tn type;           \
    };

// Supported VecTypeImpl specializations, grouped by vector width.

// Width 1: the element type itself.
DEFINE_VEC_TYPE(float, 1, float)
DEFINE_VEC_TYPE(half, 1, half)
DEFINE_VEC_TYPE(__nv_bfloat16, 1, __nv_bfloat16)

// Width 2.
DEFINE_VEC_TYPE(float, 2, float2)
DEFINE_VEC_TYPE(half, 2, half2)
#if defined(__HIP_PLATFORM_HCC__)
// On HIP, the pair type is hip_bfloat162 from this header's HIP-only section.
DEFINE_VEC_TYPE(__nv_bfloat16, 2, hip_bfloat162)
#else
DEFINE_VEC_TYPE(__nv_bfloat16, 2, __nv_bfloat162)
#endif

// Width 4. The 16-bit types use uint64_t as an opaque 8-byte container.
DEFINE_VEC_TYPE(float, 4, float4)
DEFINE_VEC_TYPE(half, 4, uint64_t)
DEFINE_VEC_TYPE(__nv_bfloat16, 4, uint64_t)

// VecType<Elem, Width> is shorthand for VecTypeImpl<Elem, Width>::type; for
// example, VecType<float, 4> is float4. A matching VecTypeImpl specialization
// must exist.
template <typename Elem, int Width>
using VecType = typename VecTypeImpl<Elem, Width>::type;