// hip_compat.cuh
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _hip_compat_cuh
#define _hip_compat_cuh

// Half-precision reciprocal computed directly through the AMDGCN rcph builtin.
// Works around a bug in hipamd (fix backported from upstream; fixed in ROCm 5.6).
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    // Unwrap to the raw representation so the builtin can be fed the underlying value.
    const __half_raw raw_in = static_cast<__half_raw>(x);
    const _Float16 reciprocal = static_cast<_Float16>(__builtin_amdgcn_rcph(raw_in.data));
    // The __half_raw result converts to the __half return type.
    return __half_raw{reciprocal};
}

// ROCm 6.0-compatible version, taken from: /opt/rocm-6.0.0/include/hip/amd_detail/amd_hip_fp16.h:1708
// Reciprocal of both lanes of a __half2: a {1, 1} _Float16_2 vector is divided
// by the input's underlying vector data, and the quotient is returned as __half2.
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{
        _Float16_2{static_cast<_Float16>(1.0f), static_cast<_Float16>(1.0f)} / x.data};
}

// Route every later use of hrcp / h2rcp to the __compat_* replacements defined
// in this header.
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// The automatic conversion to hipblasHgemm leaves the scalar and matrix
// arguments typed as `half` instead of `hipblasHalf`. This wrapper keeps the
// `half`-based signature for callers and forwards to hipblasHgemm with every
// half pointer reinterpreted as the matching hipblasHalf pointer. All other
// arguments, and the returned status, are passed through unchanged.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int                m,
                                                               int                n,
                                                               int                k,
                                                               const half*        alpha,
                                                               const half*        AP,
                                                               int                lda,
                                                               const half*        BP,
                                                               int                ldb,
                                                               const half*        beta,
                                                               half*              CP,
                                                               int                ldc) {
    // Read-only operands: the two scalars and the two input matrices.
    const hipblasHalf* alpha_h = reinterpret_cast<const hipblasHalf*>(alpha);
    const hipblasHalf* a_h     = reinterpret_cast<const hipblasHalf*>(AP);
    const hipblasHalf* b_h     = reinterpret_cast<const hipblasHalf*>(BP);
    const hipblasHalf* beta_h  = reinterpret_cast<const hipblasHalf*>(beta);
    // Output matrix is the only non-const pointer.
    hipblasHalf* c_h = reinterpret_cast<hipblasHalf*>(CP);

    return hipblasHgemm(handle, transA, transB, m, n, k,
                        alpha_h, a_h, lda, b_h, ldb, beta_h, c_h, ldc);
}
// From here on, the name hipblasHgemm expands to the half-pointer wrapper.
// This define must stay below the wrapper's definition so that the call inside
// the wrapper still reaches the real hipblasHgemm.
#define hipblasHgemm __compat_hipblasHgemm

// Previous versions of PyTorch converted to rocBLAS instead of hipBLAS; map the
// rocBLAS names such converted code uses onto their hipBLAS counterparts.
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm

#endif