cublas_wrapper.h 6.54 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
Rick Ho's avatar
Rick Ho committed
3
#include <cublas_v2.h>
Rick Ho's avatar
Rick Ho committed
4
#include <c10/util/Half.h>
Rick Ho's avatar
Rick Ho committed
5
#include <c10/util/BFloat16.h>
Rick Ho's avatar
Rick Ho committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

/**
 * Single-precision overload of the batched GEMM wrapper.
 *
 * Every argument is handed to cublasSgemmBatched unchanged, and the status
 * that call reports is returned to the caller.
 */
inline cublasStatus_t cublasXgemmBatched(
        cublasHandle_t handle,
        cublasOperation_t transa,
        cublasOperation_t transb,
        int m, int n, int k,
        const float *alpha,
        const float *Aarray[], int lda,
        const float *Barray[], int ldb,
        const float *beta,
        float *Carray[], int ldc,
        int batchCount) {
    const cublasStatus_t status = cublasSgemmBatched(
            handle, transa, transb,
            m, n, k,
            alpha,
            Aarray, lda,
            Barray, ldb,
            beta,
            Carray, ldc,
            batchCount);
    return status;
}

/**
 * Double-precision overload of the batched GEMM wrapper.
 *
 * Every argument is handed to cublasDgemmBatched unchanged, and the status
 * that call reports is returned to the caller.
 */
inline cublasStatus_t cublasXgemmBatched(
        cublasHandle_t handle,
        cublasOperation_t transa,
        cublasOperation_t transb,
        int m, int n, int k,
        const double *alpha,
        const double *Aarray[], int lda,
        const double *Barray[], int ldb,
        const double *beta,
        double *Carray[], int ldc,
        int batchCount) {
    const cublasStatus_t status = cublasDgemmBatched(
            handle, transa, transb,
            m, n, k,
            alpha,
            Aarray, lda,
            Barray, ldb,
            beta,
            Carray, ldc,
            batchCount);
    return status;
}

/**
 * Half-precision (__half) overload of the batched GEMM wrapper.
 *
 * With FMOE_USE_HIP defined, the call is routed to rocblas_hgemm_batched and
 * every half pointer (scalars and pointer arrays alike) is reinterpreted as
 * the matching rocblas_half pointer type. Otherwise the arguments are passed
 * to cublasHgemmBatched as they are.
 *
 * Returns the status reported by whichever backend call was made.
 */
inline cublasStatus_t cublasXgemmBatched(
        cublasHandle_t handle,
        cublasOperation_t transa,
        cublasOperation_t transb,
        int m, int n, int k,
        const __half *alpha,
        const __half *Aarray[], int lda,
        const __half *Barray[], int ldb,
        const __half *beta,
        __half *Carray[], int ldc,
        int batchCount) {
#ifdef FMOE_USE_HIP
    return rocblas_hgemm_batched(
            handle, transa, transb,
            m, n, k,
            reinterpret_cast<const rocblas_half *>(alpha),
            reinterpret_cast<const rocblas_half *const *>(Aarray), lda,
            reinterpret_cast<const rocblas_half *const *>(Barray), ldb,
            reinterpret_cast<const rocblas_half *>(beta),
            reinterpret_cast<rocblas_half *const *>(Carray), ldc,
            batchCount);
#else
    return cublasHgemmBatched(
            handle, transa, transb,
            m, n, k,
            alpha,
            Aarray, lda,
            Barray, ldb,
            beta,
            Carray, ldc,
            batchCount);
#endif
}


/**
 * Single-precision overload of the GEMM wrapper.
 *
 * Passes all arguments through to cublasSgemm without modification and
 * returns the status of that call.
 */
inline cublasStatus_t cublasXgemm(
        cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float *alpha,
        const float *A, int lda,
        const float *B, int ldb,
        const float *beta,
        float *C, int ldc) {
    const cublasStatus_t status = cublasSgemm(
            handle, transa, transb,
            m, n, k,
            alpha,
            A, lda,
            B, ldb,
            beta,
            C, ldc);
    return status;
}

/**
 * Double-precision overload of the GEMM wrapper.
 *
 * Passes all arguments through to cublasDgemm without modification and
 * returns the status of that call.
 */
inline cublasStatus_t cublasXgemm(
        cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double *alpha,
        const double *A, int lda,
        const double *B, int ldb,
        const double *beta,
        double *C, int ldc) {
    const cublasStatus_t status = cublasDgemm(
            handle, transa, transb,
            m, n, k,
            alpha,
            A, lda,
            B, ldb,
            beta,
            C, ldc);
    return status;
}

/**
 * Half-precision (__half) overload of the GEMM wrapper.
 *
 * With FMOE_USE_HIP defined, the call is routed to rocblas_hgemm and each
 * half pointer is reinterpreted as a rocblas_half pointer. Otherwise the
 * arguments go to cublasHgemm untouched.
 *
 * Returns the status reported by whichever backend call was made.
 */
inline cublasStatus_t cublasXgemm(
        cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const __half *alpha,
        const __half *A, int lda,
        const __half *B, int ldb,
        const __half *beta,
        __half *C, int ldc) {
#ifdef FMOE_USE_HIP
    return rocblas_hgemm(
            handle, transa, transb,
            m, n, k,
            reinterpret_cast<const rocblas_half *>(alpha),
            reinterpret_cast<const rocblas_half *>(A), lda,
            reinterpret_cast<const rocblas_half *>(B), ldb,
            reinterpret_cast<const rocblas_half *>(beta),
            reinterpret_cast<rocblas_half *>(C), ldc);
#else
    return cublasHgemm(
            handle, transa, transb,
            m, n, k,
            alpha,
            A, lda,
            B, ldb,
            beta,
            C, ldc);
#endif
}
Rick Ho's avatar
Rick Ho committed
87
88
89
90
91
92
93
94
95

/**
 * c10::Half overload of the GEMM wrapper.
 *
 * The c10::Half pointers are reinterpreted as the backend's native half
 * pointer type before the call: rocblas_half for rocblas_hgemm when
 * FMOE_USE_HIP is defined, __half for cublasHgemm otherwise. No data is
 * copied or converted; only the pointer types change.
 *
 * Returns the status reported by whichever backend call was made.
 */
inline cublasStatus_t cublasXgemm(
        cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const c10::Half *alpha,
        const c10::Half *A, int lda,
        const c10::Half *B, int ldb,
        const c10::Half *beta,
        c10::Half *C, int ldc) {
#ifdef FMOE_USE_HIP
    const rocblas_half *alpha_native = reinterpret_cast<const rocblas_half *>(alpha);
    const rocblas_half *beta_native = reinterpret_cast<const rocblas_half *>(beta);
    const rocblas_half *a_native = reinterpret_cast<const rocblas_half *>(A);
    const rocblas_half *b_native = reinterpret_cast<const rocblas_half *>(B);
    rocblas_half *c_native = reinterpret_cast<rocblas_half *>(C);

    return rocblas_hgemm(handle, transa, transb, m, n, k,
                         alpha_native,
                         a_native, lda,
                         b_native, ldb,
                         beta_native,
                         c_native, ldc);
#else
    const __half *alpha_native = reinterpret_cast<const __half *>(alpha);
    const __half *beta_native = reinterpret_cast<const __half *>(beta);
    const __half *a_native = reinterpret_cast<const __half *>(A);
    const __half *b_native = reinterpret_cast<const __half *>(B);
    __half *c_native = reinterpret_cast<__half *>(C);

    return cublasHgemm(handle, transa, transb, m, n, k,
                       alpha_native,
                       a_native, lda,
                       b_native, ldb,
                       beta_native,
                       c_native, ldc);
#endif
}
Rick Ho's avatar
Rick Ho committed
112
113
114
115
116
117
118
119
120
121
122
123
124

/**
 * c10::BFloat16 overload of the GEMM wrapper.
 *
 * CUDA build: *alpha and *beta are read in this function and converted to
 * float, then the product is issued through cublasSgemmEx with A, B and C
 * all tagged as CUDA_R_16BF. The status of that call is returned.
 *
 * HIP build (FMOE_USE_HIP): bf16 is not supported yet. The assertion still
 * fires in debug builds; when assertions are compiled out the function now
 * returns CUBLAS_STATUS_NOT_SUPPORTED instead of falling off the end of a
 * non-void function, which was undefined behaviour.
 */
inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
                                cublasOperation_t transa, cublasOperation_t transb,
                                int m, int n, int k,
                                const c10::BFloat16 *alpha,
                                const c10::BFloat16 *A, int lda,
                                const c10::BFloat16 *B, int ldb,
                                const c10::BFloat16 *beta,
                                c10::BFloat16 *C, int ldc) {
#ifdef FMOE_USE_HIP
    // TODO: Support bf16 for HIP
    assert(false);
    // NOTE(review): assumes the HIP toolchain maps CUBLAS_STATUS_NOT_SUPPORTED
    // the same way it maps cublasStatus_t -- confirm against the hipify step.
    return CUBLAS_STATUS_NOT_SUPPORTED;
#else
    // cublasSgemmEx takes float scalars, so widen the bf16 values first.
    const float alpha_fp32 = static_cast<float>(*alpha);
    const float beta_fp32 = static_cast<float>(*beta);
    return cublasSgemmEx(handle, transa, transb, m, n, k,
            &alpha_fp32,
            A, CUDA_R_16BF, lda,
            B, CUDA_R_16BF, ldb,
            &beta_fp32,
            C, CUDA_R_16BF, ldc);
#endif
}
Rick Ho's avatar
Rick Ho committed
134
135
#endif  // CUBLAS_WRAPPER_H