cublas_wrapper.h 8.15 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
#include <iostream>
Rick Ho's avatar
Rick Ho committed
2
3
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
zhanggzh's avatar
zhanggzh committed
4
5
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>

Rick Ho's avatar
Rick Ho committed
6
#include <cublas_v2.h>
Rick Ho's avatar
Rick Ho committed
7
#include <c10/util/Half.h>
Rick Ho's avatar
Rick Ho committed
8
#include <c10/util/BFloat16.h>
zhanggzh's avatar
zhanggzh committed
9
10
#include </opt/dtk/hip/include/hip/amd_detail/amd_hip_bf16.h>
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>
Rick Ho's avatar
Rick Ho committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/**
 * Single-precision (float) overload of the type-dispatched batched GEMM
 * wrapper.
 *
 * Every argument is forwarded unchanged to cublasSgemmBatched and the status
 * it reports is returned to the caller.
 */
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const float *alpha,
                                  const float *Aarray[], int lda,
                                  const float *Barray[], int ldb,
                                  const float *beta,
                                  float *Carray[], int ldc,
                                  int batchCount) {
    const cublasStatus_t status = cublasSgemmBatched(
            handle, transa, transb,
            m, n, k,
            alpha,
            Aarray, lda,
            Barray, ldb,
            beta,
            Carray, ldc,
            batchCount);
    return status;
}

/**
 * Double-precision overload of the type-dispatched batched GEMM wrapper.
 *
 * Every argument is forwarded unchanged to cublasDgemmBatched and the status
 * it reports is returned to the caller.
 */
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const double *alpha,
                                  const double *Aarray[], int lda,
                                  const double *Barray[], int ldb,
                                  const double *beta,
                                  double *Carray[], int ldc,
                                  int batchCount) {
    const cublasStatus_t status = cublasDgemmBatched(
            handle, transa, transb,
            m, n, k,
            alpha,
            Aarray, lda,
            Barray, ldb,
            beta,
            Carray, ldc,
            batchCount);
    return status;
}

/**
 * Half-precision (__half) overload of the type-dispatched batched GEMM
 * wrapper.
 *
 * When both FMOE_USE_HIP and __CUDA_MIX_HIP__ are defined the call is routed
 * to rocblas_hgemm_batched with every half pointer cast to rocblas_half;
 * otherwise cublasHgemmBatched is called with the pointers cast to
 * hipblasHalf. The status of the selected call is returned.
 *
 * NOTE(review): the pointer casts assume __half is layout-compatible with
 * rocblas_half / hipblasHalf -- confirm against the toolkit headers.
 */
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb,
                                  int m, int n, int k,
                                  const __half           *alpha,
                                  const __half           *Aarray[], int lda,
                                  const __half           *Barray[], int ldb,
                                  const __half           *beta,
                                  __half           *Carray[], int ldc,
                                  int batchCount) {
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    return rocblas_hgemm_batched(handle, transa, transb, m, n, k,
            (const rocblas_half*)alpha,
            (const rocblas_half* const*)Aarray, lda,
            (const rocblas_half* const*)Barray, ldb,
            (const rocblas_half*)beta,
            (rocblas_half* const*)Carray, ldc,
            batchCount);
#else
    return cublasHgemmBatched(handle, transa, transb, m, n, k,
            (const hipblasHalf*)alpha,
            (const hipblasHalf* const*)Aarray, lda,
            (const hipblasHalf* const*)Barray, ldb,
            (const hipblasHalf*)beta,
            (hipblasHalf* const*)Carray, ldc,
            batchCount);
#endif
}


/**
 * Single-precision (float) overload of the type-dispatched GEMM wrapper.
 *
 * Every argument is forwarded unchanged to cublasSgemm and the status it
 * reports is returned to the caller.
 */
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                cublasOperation_t transa, cublasOperation_t transb,
                                int m, int n, int k,
                                const float *alpha,
                                const float *A, int lda,
                                const float *B, int ldb,
                                const float *beta,
                                float *C, int ldc) {
    const cublasStatus_t status = cublasSgemm(
            handle, transa, transb,
            m, n, k,
            alpha,
            A, lda,
            B, ldb,
            beta,
            C, ldc);
    return status;
}

/**
 * Double-precision overload of the type-dispatched GEMM wrapper.
 *
 * Every argument is forwarded unchanged to cublasDgemm and the status it
 * reports is returned to the caller.
 */
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                cublasOperation_t transa, cublasOperation_t transb,
                                int m, int n, int k,
                                const double *alpha,
                                const double *A, int lda,
                                const double *B, int ldb,
                                const double *beta,
                                double *C, int ldc) {
    const cublasStatus_t status = cublasDgemm(
            handle, transa, transb,
            m, n, k,
            alpha,
            A, lda,
            B, ldb,
            beta,
            C, ldc);
    return status;
}

/**
 * Half-precision (__half) overload of the type-dispatched GEMM wrapper.
 *
 * When both FMOE_USE_HIP and __CUDA_MIX_HIP__ are defined the call is routed
 * to rocblas_hgemm with every half pointer cast to rocblas_half; otherwise
 * cublasHgemm is called with the pointers cast to hipblasHalf. The status of
 * the selected call is returned.
 *
 * NOTE(review): the pointer casts assume __half is layout-compatible with
 * rocblas_half / hipblasHalf -- confirm against the toolkit headers.
 */
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                cublasOperation_t transa, cublasOperation_t transb,
                                int m, int n, int k,
                                const __half *alpha,
                                const __half *A, int lda,
                                const __half *B, int ldb,
                                const __half *beta,
                                __half *C, int ldc) {
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    return rocblas_hgemm(handle, transa, transb, m, n, k,
            (const rocblas_half*)alpha,
            (const rocblas_half* )A, lda,
            (const rocblas_half* )B, ldb,
            (const rocblas_half*)beta,
            (rocblas_half* )C, ldc);
#else
    return cublasHgemm(handle, transa, transb, m, n, k,
            (const hipblasHalf*)alpha,
            (const hipblasHalf*)A, lda,
            (const hipblasHalf*)B, ldb,
            (const hipblasHalf*)beta,
            (hipblasHalf*)C, ldc);
#endif
}
Rick Ho's avatar
Rick Ho committed
95
96
97
98
99
100
101
102
103

/**
 * c10::Half overload of the type-dispatched GEMM wrapper.
 *
 * When both FMOE_USE_HIP and __CUDA_MIX_HIP__ are defined the call is routed
 * to rocblas_hgemm with every c10::Half pointer cast to rocblas_half;
 * otherwise cublasHgemm is called with the pointers cast to hipblasHalf. The
 * status of the selected call is returned.
 *
 * NOTE(review): the pointer casts assume c10::Half is layout-compatible with
 * rocblas_half / hipblasHalf -- confirm against the toolkit headers.
 */
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                cublasOperation_t transa, cublasOperation_t transb,
                                int m, int n, int k,
                                const c10::Half *alpha,
                                const c10::Half *A, int lda,
                                const c10::Half *B, int ldb,
                                const c10::Half *beta,
                                c10::Half *C, int ldc) {
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    return rocblas_hgemm(handle, transa, transb, m, n, k,
            (const rocblas_half*)alpha,
            (const rocblas_half*)A, lda,
            (const rocblas_half*)B, ldb,
            (const rocblas_half*)beta,
            (rocblas_half*)C, ldc);
#else
    return cublasHgemm(handle, transa, transb, m, n, k,
            (const hipblasHalf*)alpha,
            (const hipblasHalf*)A, lda,
            (const hipblasHalf*)B, ldb,
            (const hipblasHalf*)beta,
            (hipblasHalf*)C, ldc);
#endif
}
Rick Ho's avatar
Rick Ho committed
126
127
128
129
130

/**
 * c10::BFloat16 overload of the type-dispatched GEMM wrapper.
 *
 * When both FMOE_USE_HIP and __CUDA_MIX_HIP__ are defined, bf16 is not
 * supported: the function asserts in debug builds and returns
 * CUBLAS_STATUS_NOT_SUPPORTED so that release builds (NDEBUG) do not fall off
 * the end of a non-void function.
 *
 * Otherwise the scalars *alpha and *beta are widened to float and the call is
 * forwarded to hipblasGemmEx with HIPBLAS_R_16B as the data type of A, B and
 * C, HIPBLAS_R_32F as the compute type and HIPBLAS_GEMM_DEFAULT as the
 * algorithm. The status of that call is returned.
 *
 * alpha and beta are dereferenced inside this function on that path.
 */
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                cublasOperation_t transa, cublasOperation_t transb,
                                int m, int n, int k,
                                const c10::BFloat16 *alpha,
                                const c10::BFloat16 *A, int lda,
                                const c10::BFloat16 *B, int ldb,
                                const c10::BFloat16 *beta,
                                c10::BFloat16 *C, int ldc) {
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    // TODO: Support bf16 for HIP
    assert(false);
    // Reached only when assertions are compiled out; report the failure
    // instead of returning an indeterminate value.
    return CUBLAS_STATUS_NOT_SUPPORTED;
#else
    const hipblasDatatype_t datatype_C = HIPBLAS_R_16B;
    // The compute type below is fp32, so the scaling factors are passed as
    // float values converted from the bf16 inputs.
    const float alpha_ = static_cast<float>(*alpha);
    const float beta_ = static_cast<float>(*beta);
    return hipblasGemmEx(handle, transa, transb, m, n, k,
            &alpha_,
            reinterpret_cast<const void*>(A), datatype_C, lda,
            reinterpret_cast<const void*>(B), datatype_C, ldb,
            &beta_,
            reinterpret_cast<void*>(C), datatype_C, ldc,
            HIPBLAS_R_32F,
            HIPBLAS_GEMM_DEFAULT);
#endif
}
Rick Ho's avatar
Rick Ho committed
164
165
#endif  // CUBLAS_WRAPPER_H