// GEMM wrapper declarations: cublas_gemm_ex and cublas_strided_batched_gemm, each
// overloaded for float operands and __half operands. Definitions live elsewhere.
// Default arguments and same-name overloads make this a C++ header, not a C header.
//
// NOTE(review): this header has been collapsed onto a single line. A preprocessing
// directive extends to the end of its line, so everything after `#pragma once` below is
// parsed as extra tokens of that pragma. As written, none of the includes or declarations
// take effect. Each `#include` and each declaration must be restored to its own line.
// NOTE(review): the seven `#include` directives below have lost their header names
// (likely stripped as `<...>` markup). Restore them from the original source; TODO confirm
// which headers are required. At minimum they must provide rocblas_handle,
// rocblas_operation, __half, cublasGemmAlgo_t and the CUBLAS_GEMM_* constants.
// NOTE(review): the handle and operation parameters use rocBLAS types (rocblas_handle,
// rocblas_operation), but `algo` uses the cuBLAS type cublasGemmAlgo_t with cuBLAS
// defaults. This looks like a partial hipify. Confirm that these cuBLAS names are
// defined in the ROCm build, or switch `algo` to the rocBLAS algorithm enum.
//
// cublas_gemm_ex(handle, transa, transb, m, n, k, alpha, beta, A, B, C, algo)
//   By name, a single matrix-multiply call.
//   alpha and beta are always const float*.
//   A and B are const inputs; C is the only non-const data pointer (the output).
//   float overload:  algo defaults to CUBLAS_GEMM_DEFAULT.
//   __half overload: algo defaults to CUBLAS_GEMM_DEFAULT_TENSOR_OP.
// cublas_strided_batched_gemm(handle, m, n, k, alpha, beta, A, B, C, op_A, op_B,
//                             stride_A, stride_B, stride_C, batch, algo)
//   Batched variant that takes per-operand int strides and a batch count.
//   Note that op_A/op_B come after the data pointers here, unlike cublas_gemm_ex.
//   It uses the same default-algo split between the float and __half overloads.
// All four return int. This is presumably a status code; confirm against the definitions.
#pragma once #include #include #include #include #include #include #include int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, int k, const float* alpha, const float* beta, const float* A, const float* B, float* C, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, int k, const float* alpha, const float* beta, const __half* A, const __half* B, __half* C, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k, const float* alpha, const float* beta, const float* A, const float* B, float* C, rocblas_operation op_A, rocblas_operation op_B, int stride_A, int stride_B, int stride_C, int batch, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k, const float* alpha, const float* beta, const __half* A, const __half* B, __half* C, rocblas_operation op_A, rocblas_operation op_B, int stride_A, int stride_B, int stride_C, int batch, cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);