// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <iostream>
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>

#include <cassert>

#include <hipblas/hipblas.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include </opt/dtk/hip/include/hip/amd_detail/amd_hip_bf16.h>
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>
// cublasXgemmBatched is overloaded per element type; this overload handles
// float and forwards every argument unchanged to hipblasSgemmBatched.
// Returns whatever status hipBLAS reports.
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
                                          hipblasOperation_t transa,
                                          hipblasOperation_t transb,
                                          int m, int n, int k,
                                          const float *alpha,
                                          const float *Aarray[], int lda,
                                          const float *Barray[], int ldb,
                                          const float *beta,
                                          float *Carray[], int ldc,
                                          int batchCount) {
    const hipblasStatus_t status =
        hipblasSgemmBatched(handle, transa, transb,
                            m, n, k,
                            alpha,
                            Aarray, lda,
                            Barray, ldb,
                            beta,
                            Carray, ldc,
                            batchCount);
    return status;
}

// Double-precision member of the cublasXgemmBatched overload set: a direct
// pass-through to hipblasDgemmBatched. Returns the hipBLAS status.
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
                                          hipblasOperation_t transa,
                                          hipblasOperation_t transb,
                                          int m, int n, int k,
                                          const double *alpha,
                                          const double *Aarray[], int lda,
                                          const double *Barray[], int ldb,
                                          const double *beta,
                                          double *Carray[], int ldc,
                                          int batchCount) {
    const hipblasStatus_t status =
        hipblasDgemmBatched(handle, transa, transb,
                            m, n, k,
                            alpha,
                            Aarray, lda,
                            Barray, ldb,
                            beta,
                            Carray, ldc,
                            batchCount);
    return status;
}

// Half-precision member of the cublasXgemmBatched overload set.
// hipblasHgemmBatched expects its own half type, so the __half pointers (and
// pointer arrays) are reinterpreted in place. No element conversion happens,
// which presumes __half and the target half type share a binary layout.
// The target type depends on the build:
//   - FMOE_USE_HIP && __CUDA_MIX_HIP__ : rocblas_half
//   - otherwise                         : hipblasHalf
// Returns the hipBLAS status.
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
                                          hipblasOperation_t transa,
                                          hipblasOperation_t transb,
                                          int m, int n, int k,
                                          const __half *alpha,
                                          const __half *Aarray[], int lda,
                                          const __half *Barray[], int ldb,
                                          const __half *beta,
                                          __half *Carray[], int ldc,
                                          int batchCount) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    using HalfT = rocblas_half;
#else
    using HalfT = hipblasHalf;
#endif
    return hipblasHgemmBatched(handle, transa, transb, m, n, k,
                               reinterpret_cast<const HalfT *>(alpha),
                               reinterpret_cast<const HalfT *const *>(Aarray), lda,
                               reinterpret_cast<const HalfT *const *>(Barray), ldb,
                               reinterpret_cast<const HalfT *>(beta),
                               reinterpret_cast<HalfT *const *>(Carray), ldc,
                               batchCount);
}


// Single-precision member of the cublasXgemm overload set: forwards all
// arguments unchanged to hipblasSgemm and returns its status.
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const float *alpha,
                                   const float *A, int lda,
                                   const float *B, int ldb,
                                   const float *beta,
                                   float *C, int ldc) {
    const hipblasStatus_t status =
        hipblasSgemm(handle, transa, transb,
                     m, n, k,
                     alpha,
                     A, lda,
                     B, ldb,
                     beta,
                     C, ldc);
    return status;
}

// Double-precision member of the cublasXgemm overload set: a direct
// pass-through to hipblasDgemm. Returns the hipBLAS status.
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const double *alpha,
                                   const double *A, int lda,
                                   const double *B, int ldb,
                                   const double *beta,
                                   double *C, int ldc) {
    const hipblasStatus_t status =
        hipblasDgemm(handle, transa, transb,
                     m, n, k,
                     alpha,
                     A, lda,
                     B, ldb,
                     beta,
                     C, ldc);
    return status;
}

// Half-precision (__half) member of the cublasXgemm overload set.
// The pointers are reinterpreted in place as the half type hipblasHgemm
// expects. No element conversion happens, which presumes identical binary
// layout. The target type depends on the build:
//   - FMOE_USE_HIP && __CUDA_MIX_HIP__ : rocblas_half
//   - otherwise                         : hipblasHalf
// Returns the hipBLAS status.
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const __half *alpha,
                                   const __half *A, int lda,
                                   const __half *B, int ldb,
                                   const __half *beta,
                                   __half *C, int ldc) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    using HalfT = rocblas_half;
#else
    using HalfT = hipblasHalf;
#endif
    return hipblasHgemm(handle, transa, transb, m, n, k,
                        reinterpret_cast<const HalfT *>(alpha),
                        reinterpret_cast<const HalfT *>(A), lda,
                        reinterpret_cast<const HalfT *>(B), ldb,
                        reinterpret_cast<const HalfT *>(beta),
                        reinterpret_cast<HalfT *>(C), ldc);
}

// Half-precision GEMM for PyTorch's c10::Half element type.
// The c10::Half pointers are reinterpreted in place (no element conversion)
// as the half type hipblasHgemm expects, which presumes the two types share a
// binary layout. The target type depends on the build:
//   - FMOE_USE_HIP && __CUDA_MIX_HIP__ : rocblas_half
//   - otherwise                         : hipblasHalf
// Returns the hipBLAS status.
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const c10::Half *alpha,
                                   const c10::Half *A, int lda,
                                   const c10::Half *B, int ldb,
                                   const c10::Half *beta,
                                   c10::Half *C, int ldc) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    using HalfT = rocblas_half;
#else
    using HalfT = hipblasHalf;
#endif
    return hipblasHgemm(handle, transa, transb, m, n, k,
                        reinterpret_cast<const HalfT *>(alpha),
                        reinterpret_cast<const HalfT *>(A), lda,
                        reinterpret_cast<const HalfT *>(B), ldb,
                        reinterpret_cast<const HalfT *>(beta),
                        reinterpret_cast<HalfT *>(C), ldc);
}

// bfloat16 member of the cublasXgemm overload set.
//
// The computation goes through hipblasGemmEx:
//   - A, B and C are typed as HIPBLAS_R_16B.
//   - The compute type is HIPBLAS_R_32F, so alpha and beta are widened from
//     c10::BFloat16 to float before the call.
//
// NOTE(review): *alpha and *beta are dereferenced on the host, so this assumes
// the handle uses host pointer mode. Confirm against callers.
//
// Returns the hipBLAS status. In the FMOE_USE_HIP && __CUDA_MIX_HIP__ build,
// bf16 is not implemented; there it asserts in debug builds and returns
// HIPBLAS_STATUS_NOT_SUPPORTED.
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
                                   hipblasOperation_t transa, hipblasOperation_t transb,
                                   int m, int n, int k,
                                   const c10::BFloat16 *alpha,
                                   const c10::BFloat16 *A, int lda,
                                   const c10::BFloat16 *B, int ldb,
                                   const c10::BFloat16 *beta,
                                   c10::BFloat16 *C, int ldc) {
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
    // TODO: Support bf16 for HIP
    assert(false && "bf16 GEMM is not supported in the mixed CUDA/HIP build");
    // With NDEBUG the assert compiles away. Flowing off the end of a non-void
    // function is undefined behaviour, so always report an explicit status.
    return HIPBLAS_STATUS_NOT_SUPPORTED;
#else
    // The same storage type is used for all three matrices (A, B and C).
    const hipblasDatatype_t bf16_type = HIPBLAS_R_16B;
    // The scalars must match the HIPBLAS_R_32F compute type, hence float.
    const float alpha_fp32 = static_cast<float>(*alpha);
    const float beta_fp32 = static_cast<float>(*beta);
    return hipblasGemmEx(handle, transa, transb, m, n, k,
                         &alpha_fp32,
                         A, bf16_type, lda,
                         B, bf16_type, ldb,
                         &beta_fp32,
                         C, bf16_type, ldc,
                         HIPBLAS_R_32F,
                         HIPBLAS_GEMM_DEFAULT);
#endif
}
#endif  // CUBLAS_WRAPPER_H

