/*
Copyright The Microsoft DeepSpeed Team
*/

5
6
7
8
9
10
11
#pragma once

#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
aiss's avatar
aiss committed
12
#ifndef __HIP_PLATFORM_HCC__
13
#include <mma.h>
aiss's avatar
aiss committed
14
#endif
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <stdio.h>

int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
aiss's avatar
aiss committed
28
29
30
#ifdef __HIP_PLATFORM_HCC__
                   rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
31
                   cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
aiss's avatar
aiss committed
32
#endif
33
34
35
36
37
38
39
40
41
42
43
44

int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
aiss's avatar
aiss committed
45
46
47
#ifdef __HIP_PLATFORM_HCC__
                   rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
48
                   cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
aiss's avatar
aiss committed
49
#endif
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
aiss's avatar
aiss committed
66
67
68
#ifdef __HIP_PLATFORM_HCC__
                                rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
69
                                cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
aiss's avatar
aiss committed
70
#endif
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
aiss's avatar
aiss committed
87
88
89
#ifdef __HIP_PLATFORM_HCC__
                                rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
90
                                cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
aiss's avatar
aiss committed
91
#endif