cublas_wrapper.h 8.28 KB
Newer Older
zhanggzh's avatar
zhanggzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <iostream>
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>

#include <hipblas/hipblas.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include </opt/dtk/hip/include/hip/amd_detail/amd_hip_bf16.h>
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>
/**
 * Batched single-precision GEMM: for each i in [0, batchCount),
 *   Carray[i] = alpha * op(Aarray[i]) * op(Barray[i]) + beta * Carray[i].
 *
 * This is the float member of the cublasXgemmBatched overload set; the
 * element type of the arguments selects the hipBLAS routine. Every argument
 * is forwarded unchanged to hipblasSgemmBatched.
 *
 * @return the status reported by hipblasSgemmBatched.
 */
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
                                          hipblasOperation_t transa,
                                          hipblasOperation_t transb,
                                          int m, int n, int k,
                                          const float *alpha,
                                          const float *Aarray[], int lda,
                                          const float *Barray[], int ldb,
                                          const float *beta,
                                          float *Carray[], int ldc,
                                          int batchCount) {
    const hipblasStatus_t status =
        hipblasSgemmBatched(handle, transa, transb,
                            m, n, k,
                            alpha,
                            Aarray, lda,
                            Barray, ldb,
                            beta,
                            Carray, ldc,
                            batchCount);
    return status;
}

/**
 * Batched double-precision GEMM: for each i in [0, batchCount),
 *   Carray[i] = alpha * op(Aarray[i]) * op(Barray[i]) + beta * Carray[i].
 *
 * This is the double member of the cublasXgemmBatched overload set; every
 * argument is forwarded unchanged to hipblasDgemmBatched.
 *
 * @return the status reported by hipblasDgemmBatched.
 */
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
                                          hipblasOperation_t transa,
                                          hipblasOperation_t transb,
                                          int m, int n, int k,
                                          const double *alpha,
                                          const double *Aarray[], int lda,
                                          const double *Barray[], int ldb,
                                          const double *beta,
                                          double *Carray[], int ldc,
                                          int batchCount) {
    const hipblasStatus_t status =
        hipblasDgemmBatched(handle, transa, transb,
                            m, n, k,
                            alpha,
                            Aarray, lda,
                            Barray, ldb,
                            beta,
                            Carray, ldc,
                            batchCount);
    return status;
}

/**
 * Batched half-precision GEMM: for each i in [0, batchCount),
 *   Carray[i] = alpha * op(Aarray[i]) * op(Barray[i]) + beta * Carray[i].
 *
 * This is the __half member of the cublasXgemmBatched overload set. All
 * pointer arguments are reinterpreted as BlasHalf before calling
 * hipblasHgemmBatched:
 *   - rocblas_half when FMOE_USE_HIP and __CUDA_MIX_HIP__ are both defined,
 *   - hipblasHalf otherwise.
 * NOTE(review): the reinterpretation assumes BlasHalf has the same 16-bit
 * layout as __half; confirm against the hipBLAS version in use.
 *
 * @return the status reported by hipblasHgemmBatched.
 */
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
                                          hipblasOperation_t transa,
                                          hipblasOperation_t transb,
                                          int m, int n, int k,
                                          const __half *alpha,
                                          const __half *Aarray[], int lda,
                                          const __half *Barray[], int ldb,
                                          const __half *beta,
                                          __half *Carray[], int ldc,
                                          int batchCount) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    typedef rocblas_half BlasHalf;
#else
    typedef hipblasHalf BlasHalf;
#endif
    return hipblasHgemmBatched(handle, transa, transb, m, n, k,
                               reinterpret_cast<const BlasHalf *>(alpha),
                               reinterpret_cast<const BlasHalf *const *>(Aarray), lda,
                               reinterpret_cast<const BlasHalf *const *>(Barray), ldb,
                               reinterpret_cast<const BlasHalf *>(beta),
                               reinterpret_cast<BlasHalf *const *>(Carray), ldc,
                               batchCount);
}


/**
 * Single-precision GEMM: C = alpha * op(A) * op(B) + beta * C.
 *
 * This is the float member of the cublasXgemm overload set; every argument
 * is forwarded unchanged to hipblasSgemm.
 *
 * @return the status reported by hipblasSgemm.
 */
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const float *alpha,
                                   const float *A, int lda,
                                   const float *B, int ldb,
                                   const float *beta,
                                   float *C, int ldc) {
    const hipblasStatus_t status =
        hipblasSgemm(handle, transa, transb,
                     m, n, k,
                     alpha,
                     A, lda,
                     B, ldb,
                     beta,
                     C, ldc);
    return status;
}

/**
 * Double-precision GEMM: C = alpha * op(A) * op(B) + beta * C.
 *
 * This is the double member of the cublasXgemm overload set; every argument
 * is forwarded unchanged to hipblasDgemm.
 *
 * @return the status reported by hipblasDgemm.
 */
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const double *alpha,
                                   const double *A, int lda,
                                   const double *B, int ldb,
                                   const double *beta,
                                   double *C, int ldc) {
    const hipblasStatus_t status =
        hipblasDgemm(handle, transa, transb,
                     m, n, k,
                     alpha,
                     A, lda,
                     B, ldb,
                     beta,
                     C, ldc);
    return status;
}

/**
 * Half-precision GEMM on __half buffers: C = alpha * op(A) * op(B) + beta * C.
 *
 * All pointer arguments are reinterpreted as BlasHalf before calling
 * hipblasHgemm:
 *   - rocblas_half when FMOE_USE_HIP and __CUDA_MIX_HIP__ are both defined,
 *   - hipblasHalf otherwise.
 * NOTE(review): the reinterpretation assumes BlasHalf has the same 16-bit
 * layout as __half; confirm against the hipBLAS version in use.
 *
 * @return the status reported by hipblasHgemm.
 */
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const __half *alpha,
                                   const __half *A, int lda,
                                   const __half *B, int ldb,
                                   const __half *beta,
                                   __half *C, int ldc) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    typedef rocblas_half BlasHalf;
#else
    typedef hipblasHalf BlasHalf;
#endif
    return hipblasHgemm(handle, transa, transb, m, n, k,
                        reinterpret_cast<const BlasHalf *>(alpha),
                        reinterpret_cast<const BlasHalf *>(A), lda,
                        reinterpret_cast<const BlasHalf *>(B), ldb,
                        reinterpret_cast<const BlasHalf *>(beta),
                        reinterpret_cast<BlasHalf *>(C), ldc);
}

/**
 * Half-precision GEMM on PyTorch c10::Half buffers:
 *   C = alpha * op(A) * op(B) + beta * C.
 *
 * All pointer arguments are reinterpreted as BlasHalf before calling
 * hipblasHgemm:
 *   - rocblas_half when FMOE_USE_HIP and __CUDA_MIX_HIP__ are both defined,
 *   - hipblasHalf otherwise.
 * NOTE(review): the reinterpretation assumes c10::Half and BlasHalf share the
 * same 16-bit IEEE half layout; confirm against the c10/hipBLAS versions in use.
 *
 * @return the status reported by hipblasHgemm.
 */
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const c10::Half *alpha,
                                   const c10::Half *A, int lda,
                                   const c10::Half *B, int ldb,
                                   const c10::Half *beta,
                                   c10::Half *C, int ldc) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    typedef rocblas_half BlasHalf;
#else
    typedef hipblasHalf BlasHalf;
#endif
    const BlasHalf *alpha_h = reinterpret_cast<const BlasHalf *>(alpha);
    const BlasHalf *beta_h = reinterpret_cast<const BlasHalf *>(beta);
    const BlasHalf *a_h = reinterpret_cast<const BlasHalf *>(A);
    const BlasHalf *b_h = reinterpret_cast<const BlasHalf *>(B);
    BlasHalf *c_h = reinterpret_cast<BlasHalf *>(C);
    return hipblasHgemm(handle, transa, transb, m, n, k,
                        alpha_h,
                        a_h, lda,
                        b_h, ldb,
                        beta_h,
                        c_h, ldc);
}

/**
 * bfloat16 GEMM on PyTorch c10::BFloat16 buffers:
 *   C = alpha * op(A) * op(B) + beta * C.
 *
 * A, B and C are passed to hipblasGemmEx as HIPBLAS_R_16B with
 * HIPBLAS_R_32F as the compute type. alpha and beta are therefore widened
 * to float before the call.
 *
 * *alpha and *beta are read on the host here, so they must be host pointers.
 * The widened copies are also passed by host address.
 * NOTE(review): this assumes the handle is in host pointer mode; confirm.
 *
 * @return the status reported by hipblasGemmEx. In builds where FMOE_USE_HIP
 *         and __CUDA_MIX_HIP__ are both defined, bf16 is not implemented:
 *         debug builds assert, and all builds return
 *         HIPBLAS_STATUS_NOT_SUPPORTED.
 */
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const c10::BFloat16 *alpha,
                                   const c10::BFloat16 *A, int lda,
                                   const c10::BFloat16 *B, int ldb,
                                   const c10::BFloat16 *beta,
                                   c10::BFloat16 *C, int ldc) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    // TODO: Support bf16 for HIP.
    // The assert vanishes under NDEBUG, so a status must still be returned:
    // falling off the end of a non-void function is undefined behavior.
    assert(false && "cublasXgemm: bf16 GEMM is not supported in this HIP build");
    return HIPBLAS_STATUS_NOT_SUPPORTED;
#else
    const hipblasDatatype_t bf16_type = HIPBLAS_R_16B;
    // The compute type below is fp32, so the scaling factors are fp32 too.
    const float alpha_f32 = static_cast<float>(*alpha);
    const float beta_f32 = static_cast<float>(*beta);
    return hipblasGemmEx(handle, transa, transb, m, n, k,
                         &alpha_f32,
                         A, bf16_type, lda,
                         B, bf16_type, ldb,
                         &beta_f32,
                         C, bf16_type, ldc,
                         HIPBLAS_R_32F,
                         HIPBLAS_GEMM_DEFAULT);
#endif
}
#endif  // CUBLAS_WRAPPER_H