hip_compat.cuh 2.63 KB
Newer Older
fxmarty's avatar
fxmarty committed
1
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
Nicolas Patry's avatar
Nicolas Patry committed
2

fxmarty's avatar
fxmarty committed
3
4
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
Nicolas Patry's avatar
Nicolas Patry committed
5

fxmarty's avatar
fxmarty committed
6
7
8
9
10
11
12
13
14
15
16
17
18
// Workaround for a bug in hipamd, backported from upstream; the bug is fixed in ROCm 5.6.
//
// Half-precision reciprocal of `x`: the raw payload of `x` is handed to the
// AMDGPU rcph builtin and the result is wrapped back into a __half.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    // Expose the underlying storage so the builtin can be fed directly.
    const __half_raw bits = static_cast<__half_raw>(x);
    const _Float16 reciprocal = static_cast<_Float16>(__builtin_amdgcn_rcph(bits.data));
    return __half_raw{reciprocal};
}

// Two-lane counterpart of the scalar reciprocal replacement: applies the
// AMDGPU rcph builtin to `x.x` and `x.y` separately and repacks the two
// results into a __half2.
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    const _Float16 rcp_first = static_cast<_Float16>(__builtin_amdgcn_rcph(x.x));
    const _Float16 rcp_second = static_cast<_Float16>(__builtin_amdgcn_rcph(x.y));
    return _Float16_2{rcp_first, rcp_second};
}

// From here on, any use of the names `hrcp` / `h2rcp` in a translation unit
// that includes this header expands to the __compat_* replacements above.
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
Nicolas Patry's avatar
Nicolas Patry committed
19

fxmarty's avatar
fxmarty committed
20
// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
Nicolas Patry's avatar
Nicolas Patry committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// Thin adapter around hipblasHgemm for callers that hold `half` pointers.
//
// Every `half*` argument (alpha, AP, BP, beta, CP) is reinterpreted as a
// `hipblasHalf*`; handle, transpose flags, m/n/k and the leading dimensions
// are forwarded unchanged. Returns whatever status hipblasHgemm reports.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
    hipblasHandle_t handle,
    hipblasOperation_t transA, hipblasOperation_t transB,
    int m, int n, int k,
    const half* alpha,
    const half* AP, int lda,
    const half* BP, int ldb,
    const half* beta,
    half* CP, int ldc) {
    // Pointer reinterpretation only; no data is copied or converted.
    const hipblasHalf* alpha_h = reinterpret_cast<const hipblasHalf *>(alpha);
    const hipblasHalf* a_h     = reinterpret_cast<const hipblasHalf *>(AP);
    const hipblasHalf* b_h     = reinterpret_cast<const hipblasHalf *>(BP);
    const hipblasHalf* beta_h  = reinterpret_cast<const hipblasHalf *>(beta);
    hipblasHalf*       c_h     = reinterpret_cast<hipblasHalf *>(CP);

    return hipblasHgemm(handle, transA, transB, m, n, k,
                        alpha_h, a_h, lda, b_h, ldb, beta_h, c_h, ldc);
}
// Must stay below the wrapper definition: the wrapper body above still refers
// to the real hipblasHgemm, while every later use of the name expands to the
// wrapper.
#define hipblasHgemm __compat_hipblasHgemm

// Previous versions of PyTorch were converting to rocBLAS instead of hipBLAS.
fxmarty's avatar
fxmarty committed
45
// Spell the rocBLAS handle type name as the hipBLAS handle type.
#define rocblas_handle hipblasHandle_t
Nicolas Patry's avatar
Nicolas Patry committed
46
// Spell the rocBLAS "none" operation constant as HIPBLAS_OP_N.
#define rocblas_operation_none HIPBLAS_OP_N
fxmarty's avatar
fxmarty committed
47
48
// Redirect the rocBLAS stream getter/setter names to their hipBLAS equivalents.
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
Nicolas Patry's avatar
Nicolas Patry committed
49
50
// rocblas_hgemm is routed to the same half-pointer-casting wrapper that
// replaces hipblasHgemm.
#define rocblas_hgemm __compat_hipblasHgemm

OlivierDehaene's avatar
OlivierDehaene committed
51
#endif