// SPDX-License-Identifier: MIT // #ifdef __gfx908__ // // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below // just for gfx908 and not for others // // below lines enable hip float to half conversion which are disabled by // default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef // __HIP_NO_HALF_CONVERSIONS__ #endif #include "hipbsolgemm.cuh" // #include // #ifdef USE_ROCM // #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + // ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL // (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #endif // #ifdef __HIP_PLATFORM_HCC__ // #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + // ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL // (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL // #ifdef ROCM_BACKWARD_PASS_GUARD // flag = at::BackwardPassGuard::is_backward_pass() ? // rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif #ifndef CHECK_HIP_ERROR #define CHECK_HIP_ERROR(error) \ if (error != hipSuccess) \ { \ fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ hipGetErrorString(error), error, __FILE__, __LINE__); \ exit(EXIT_FAILURE); \ } #endif #ifndef CHECK_HIPBLAS_ERROR #define CHECK_HIPBLAS_ERROR(error) \ if (error != HIPBLAS_STATUS_SUCCESS) \ { \ fprintf(stderr, "hipBLAS error: '%s'(%d) at %s:%d\n", \ hipblasStatusToString(error), error, __FILE__, __LINE__); \ exit(EXIT_FAILURE); \ } #endif namespace { /*thread_local*/ cudaStream_t weight_stream; // BUG: DLM has event and stream on different devices error // In multi-GPU scenerio, do names defined in this namespace exist on all // devices? C++ keyword: thread_local <- maybe this can help? /*thread_local*/ cudaEvent_t event; // hipBLASLt hipblasLtHandle_t hipblaslt_handle; hipblasLtMatmulPreference_t preference; size_t workspace_size = 2 * 128 * 1024 * 1024; // uint64_t workspace_size = 0; void *d_workspace; int request_solutions = 1; int returnedAlgoCount = 0; struct MatMulConfig { hipblasOperation_t op_A; hipblasOperation_t op_B; int M; int N; int K; hipblasDatatype_t dtype; friend auto operator<(const MatMulConfig &left, const MatMulConfig &right) -> bool { return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); } }; // std::map, // std::vector> heuristic_map; std::map heuristic_map; hipEvent_t start, stop; int bench_iters{1}; int warmup_iters{1}; bool cout_print = false; torch::Tensor dTensor; std::map dtype_map{ {at::kHalf, HIPBLAS_R_16F}, {at::kBFloat16, HIPBLAS_R_16B}, {at::kFloat, HIPBLAS_R_32F}, {at::kChar, HIPBLAS_R_8I} #ifdef ENABLE_TORCH_FP8 , {at::kFloat8_e4m3fn, HIP_R_8F_E4M3} #endif }; // std::vector heuristicResult; } // namespace // find all hipblaslt solutions for given gemm problem std::vector hipblasLtMatmul_findallsols_wrapper( hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, int m, int n, int k, const void *alpha, const void *a, int lda, const void *b, int ldb, const void *beta, void *c, int ldc, const void *bias, hipblasDatatype_t intype, hipblasDatatype_t outtype, const void *scaleA, const void *scaleB, const void *scaleC, const int scaleType, hipStream_t &stream) { int flag{0}; hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); CHECK_HIPBLAS_ERROR( hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIPBLAS_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); if (bias) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(void *))); auto epilogue = HIPBLASLT_EPILOGUE_BIAS; CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); } if (scaleA != nullptr) { if (scaleType == 1) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, &scaleA, sizeof(scaleA))); } else if (scaleType == 2) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_BLOC_EXT, &scaleA, sizeof(scaleA))); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA))); } } if (scaleB != nullptr) { if (scaleType == 1) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, &scaleB, sizeof(scaleB))); } else if (scaleType == 2) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_BLOC_EXT, &scaleB, sizeof(scaleB))); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); } } if (scaleC != nullptr) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); } // std::vector heuristicResult(10); // CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( // handle, matmul, matA, matB, matC, matC, // preference, 10, heuristicResult.data(), &returnedAlgoCount)); std::vector heuristicResult; CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos( handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, op_A, op_B, intype, intype, outtype, outtype, HIPBLAS_COMPUTE_32F, heuristicResult)); std::vector algoIndex; int returned_algo_count = heuristicResult.size(); // for (int i = 0; i < returnedAlgoCount; i++) { for (int i = 0; i < returned_algo_count; i++) { auto algo = heuristicResult[i].algo; size_t ret_workspace_size = 0; auto status = hipblaslt_ext::matmulIsAlgoSupported( handle, matmul, alpha, matA, matB, beta, matC, matC, algo, ret_workspace_size); if (status == HIPBLAS_STATUS_SUCCESS) { if (ret_workspace_size < workspace_size) { algoIndex.push_back(hipblaslt_ext::getIndexFromAlgo(algo)); } } } CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); return algoIndex; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// /** * hipBLASLt GEMM call */ hipblasStatus_t hipblasLtMatmul_sol_wrapper( hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, int m, int n, int k, const void *alpha, const void *a, int lda, const void *scaleA, const void *b, int ldb, const void *scaleB, const void *beta, void *c, int ldc, const void *scaleC, const int scaleType, const void *bias, hipblasDatatype_t intype, hipblasDatatype_t outtype, hipStream_t &stream, int solution_index = -1) { // TODO: flag is not supported for hipblasLt yet int flag{0}; // if (dtype == HIPBLAS_R_16F) { // flag = rocblas_gemm_flags_fp16_alt_impl; //} // nvtxRangePushA("hipBLASLt variables creation"); hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); CHECK_HIPBLAS_ERROR( hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIPBLAS_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); if (scaleA != nullptr) { if (scaleType == 1) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, &scaleA, sizeof(scaleA))); } else if (scaleType == 2) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_BLOC_EXT, &scaleA, sizeof(scaleA))); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA))); } } if (scaleB != nullptr) { if (scaleType == 1) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, &scaleB, sizeof(scaleB))); } else if (scaleType == 2) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_BLOC_EXT, &scaleB, sizeof(scaleB))); } else { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); } } if (scaleC != nullptr) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); } if (bias) { CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(void *))); auto epilogue = HIPBLASLT_EPILOGUE_BIAS; static_assert(sizeof(epilogue) == sizeof(int32_t)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); } // nvtxRangePop(); // if heuristic does not exist in the map, do search and push into the map // auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; // if (heuristic_map.count(gemm_key) <= 0) { std::vector heuristicResult(1); if (solution_index < 0) { // nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); if (cout_print) { std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl; } if (cout_print) { std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") << " (" << m << ", " << n << ", " << k << "), dtype: " << intype << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; } CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( handle, matmul, matA, matB, matC, matC, preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); if ((returnedAlgoCount != request_solutions) && cout_print) { std::cout << "less solution found! request: " << request_solutions << ", found: " << returnedAlgoCount << std::endl; } } else { std::vector algoIndex(1); algoIndex[0] = solution_index; CHECK_HIPBLAS_ERROR( hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); } hipblasStatus_t status = hipblasLtMatmul( handle, matmul, alpha, a, matA, b, matB, beta, c, matC, c, matC, &heuristicResult[0].algo, d_workspace, workspace_size, stream); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); return status; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// torch::Tensor hipb_mm(const torch::Tensor &mat1, const torch::Tensor &mat2, const int solution_index, std::optional bias, std::optional out_dtype, std::optional scaleA, std::optional scaleB, std::optional scaleOut, std::optional scaleType) { auto mat1_strides{mat1.strides()}; auto mat2_strides{mat2.strides()}; auto mat1_sizes{mat1.sizes()}; auto mat2_sizes{mat2.sizes()}; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); TORCH_CHECK(mat1.dtype() == mat2.dtype(), "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()); TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); auto inDtype{mat1.options().dtype().toScalarType()}; auto outDtype{ out_dtype.has_value() ? torch::python::detail::py_object_to_dtype(out_dtype.value()) : inDtype}; auto options{at::TensorOptions().dtype(outDtype).device(at::kCUDA)}; auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { assert(false && "unusual strides detected, may need to clone a contiguous tensor"); } if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { assert(false && "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { bool tmp = transpose_mat1; transpose_mat1 = !transpose_mat2; transpose_mat2 = !tmp; mat1_strides = mat2.strides(); mat2_strides = mat1.strides(); mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } float one{1.0f}; float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); void *d_scaleA = nullptr, *d_scaleB = nullptr, *d_scaleOut = nullptr; int scale_type = 0; if (scaleA.has_value()) { d_scaleA = static_cast(scaleA.value().data_ptr()); } if (scaleB.has_value()) { d_scaleB = static_cast(scaleB.value().data_ptr()); } if (scaleOut.has_value()) { d_scaleOut = static_cast(scaleOut.value().data_ptr()); } if (scaleType.has_value()) { scale_type = scaleType.value(); } auto hipblasInType = dtype_map.at(inDtype); auto hipblasOutType = dtype_map.at(outDtype); void *ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; void *ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; void *ptrC{static_cast(result.data_ptr())}; if (transpose_result) std::swap(d_scaleA, d_scaleB); auto current_stream{torch::hip::getCurrentHIPStream().stream()}; void *bias_ptr = bias.has_value() ? static_cast(bias.value().data_ptr()) : nullptr; CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, mat1_ld, d_scaleA, ptrB, mat2_ld, d_scaleB, &zero, ptrC, result_ld, d_scaleOut, scale_type, bias_ptr, hipblasInType, hipblasOutType, current_stream, solution_index)); return result; } // find all hipblas solutions and return them to python land std::vector hipb_findallsols( const torch::Tensor &mat1, const torch::Tensor &mat2, std::optional bias, std::optional out_dtype, std::optional scaleA, std::optional scaleB, std::optional scaleC, std::optional scaleType) { auto mat1_strides{mat1.strides()}; auto mat2_strides{mat2.strides()}; auto mat1_sizes{mat1.sizes()}; auto mat2_sizes{mat2.sizes()}; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); TORCH_CHECK(mat1.dtype() == mat2.dtype(), "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()); TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); auto inType{mat1.options().dtype().toScalarType()}; auto outType{ out_dtype.has_value() ? torch::python::detail::py_object_to_dtype(out_dtype.value()) : inType}; auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { assert(false && "unusual strides detected, may need to clone a contiguous tensor"); } if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { assert(false && "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { bool tmp = transpose_mat1; transpose_mat1 = !transpose_mat2; transpose_mat2 = !tmp; mat1_strides = mat2.strides(); mat2_strides = mat1.strides(); mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } float one{1.0f}; float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); hipblasDatatype_t hipblasInType = dtype_map.at(inType); hipblasDatatype_t hipblasOutType = dtype_map.at(outType); void *ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; void *ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; void *ptrC{static_cast(result.data_ptr())}; auto current_stream{torch::hip::getCurrentHIPStream().stream()}; auto bias_ptr = bias.has_value() ? static_cast(bias.value().data_ptr()) : nullptr; auto scaleA_ptr = scaleA.has_value() ? static_cast(scaleA.value().data_ptr()) : nullptr; auto scaleB_ptr = scaleB.has_value() ? static_cast(scaleB.value().data_ptr()) : nullptr; auto scaleC_ptr = scaleC.has_value() ? static_cast(scaleC.value().data_ptr()) : nullptr; int scale_type = scaleType.has_value() ? scaleType.value() : 0; return hipblasLtMatmul_findallsols_wrapper( hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, mat1_ld, ptrB, mat2_ld, &zero, ptrC, result_ld, bias_ptr, hipblasInType, hipblasOutType, scaleA_ptr, scaleB_ptr, scaleC_ptr, scale_type, current_stream); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// void hipb_create_extension() { // CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); // CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); // hipBLASLt CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); // CHECK_HIP_ERROR(hipEventCreate(&start)); // CHECK_HIP_ERROR(hipEventCreate(&stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// void hipb_destroy_extension() { // CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); // CHECK_HIP_ERROR(hipEventDestroy(event)); // hipBLASLt CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); CHECK_HIP_ERROR(hipFree(d_workspace)); // CHECK_HIP_ERROR(hipEventDestroy(start)); // CHECK_HIP_ERROR(hipEventDestroy(stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// std::string getHipblasltKernelName(int solution_index) { std::vector heuristicResult(1); std::vector algoIndex(1); algoIndex[0] = solution_index; CHECK_HIPBLAS_ERROR( hipblaslt_ext::getAlgosFromIndex(hipblaslt_handle, algoIndex, heuristicResult)); return hipblaslt_ext::getKernelNameFromAlgo(hipblaslt_handle, heuristicResult[0].algo); } #ifndef PREBUILD_KERNELS PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); m.def("hipb_mm", &hipb_mm, "hipb_mm", py::arg("mat1"), py::arg("mat2"), py::arg("solution_index"), py::arg("bias") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("scaleA") = std::nullopt, py::arg("scaleB") = std::nullopt, py::arg("scaleOut") = std::nullopt, py::arg("scaleType") = std::nullopt); m.def("hipb_findallsols", &hipb_findallsols, "hipb_findallsols", py::arg("mat1"), py::arg("mat2"), py::arg("bias") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("scaleA") = std::nullopt, py::arg("scaleB") = std::nullopt, py::arg("scaleC") = std::nullopt, py::arg("scaleType") = std::nullopt); m.def("getHipblasltKernelName", &getHipblasltKernelName); } #endif