Commit 388ac735 authored by wenjh
Browse files

[rocblas] Use HandleManager to avoid mem leakage


Signed-off-by: wenjh <wenjh@sugon.com>

[RocblasGemm] Provide support of AB(bf16)D(fp32)
Signed-off-by: wenjh <wenjh@sugon.com>
parent 08f14085
......@@ -29,6 +29,7 @@
#include <cstdint>
#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/vectorized_pointwise.h"
#include "../util/logging.h"
......@@ -1352,6 +1353,12 @@ void hipblaslt_gemm(const Tensor *inputA,
#endif //USE_HIPBLASLT
#ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion
// Creates a new rocblas handle and stores it in *handle.
// The status returned by rocblas_create_handle() is passed to
// NVTE_CHECK_ROCBLAS (presumably reports an error on a non-success status;
// confirm against the macro definition).
// Used as the creator template argument of detail::HandleManager in the
// rocblasHandleManager alias.
inline void CreateRocblasHandle(rocblas_handle *handle) {
NVTE_CHECK_ROCBLAS(rocblas_create_handle(handle));
}
// detail::HandleManager instantiated for rocblas_handle, with
// CreateRocblasHandle as its creation function. rocblas_gemm obtains its
// handle through rocblasHandleManager::Instance().GetHandle().
// NOTE(review): the handle lifetime/caching policy is defined in
// ../util/handle_manager.h and is not visible here; confirm there.
using rocblasHandleManager = detail::HandleManager<rocblas_handle, CreateRocblasHandle>;
void rocblas_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
......@@ -1419,8 +1426,7 @@ void rocblas_gemm(const Tensor *inputA,
alpha = A_scale_inv * B_scale_inv;
}
rocblas_handle handle;
NVTE_CHECK_ROCBLAS(rocblas_create_handle(&handle));
rocblas_handle handle = rocblasHandleManager::Instance().GetHandle();
NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream));
// extract the stream order alloc env
......@@ -1435,6 +1441,7 @@ void rocblas_gemm(const Tensor *inputA,
NVTE_CHECK((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r) ||
(A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r) ||
(A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f32_r && B_type==rocblas_datatype_f32_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f16_r) ||
......@@ -1524,8 +1531,6 @@ void rocblas_gemm(const Tensor *inputA,
computeType, rocblas_gemm_algo::rocblas_gemm_algo_standard,0,flags));
}
NVTE_CHECK_ROCBLAS(rocblas_destroy_handle(handle));
int batch_size, input_dim, output_dim;
if (bias && gelu) {
if (grad) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment