Commit 388ac735 authored by wenjh
Browse files

[rocblas] Use HandleManager to avoid mem leakage


Signed-off-by: wenjh <wenjh@sugon.com>

[RocblasGemm] Provide support of AB(bf16)D(fp32)
Signed-off-by: wenjh <wenjh@sugon.com>
parent 08f14085
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <cstdint> #include <cstdint>
#include "../common.h" #include "../common.h"
#include "../util/handle_manager.h"
#include "../util/vectorized_pointwise.h" #include "../util/vectorized_pointwise.h"
#include "../util/logging.h" #include "../util/logging.h"
...@@ -1352,6 +1353,12 @@ void hipblaslt_gemm(const Tensor *inputA, ...@@ -1352,6 +1353,12 @@ void hipblaslt_gemm(const Tensor *inputA,
#endif //USE_HIPBLASLT #endif //USE_HIPBLASLT
#ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion #ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion
inline void CreateRocblasHandle(rocblas_handle *handle) {
NVTE_CHECK_ROCBLAS(rocblas_create_handle(handle));
}
using rocblasHandleManager = detail::HandleManager<rocblas_handle, CreateRocblasHandle>;
void rocblas_gemm(const Tensor *inputA, void rocblas_gemm(const Tensor *inputA,
const Tensor *inputB, const Tensor *inputB,
Tensor *outputD, Tensor *outputD,
...@@ -1419,8 +1426,7 @@ void rocblas_gemm(const Tensor *inputA, ...@@ -1419,8 +1426,7 @@ void rocblas_gemm(const Tensor *inputA,
alpha = A_scale_inv * B_scale_inv; alpha = A_scale_inv * B_scale_inv;
} }
rocblas_handle handle; rocblas_handle handle = rocblasHandleManager::Instance().GetHandle();
NVTE_CHECK_ROCBLAS(rocblas_create_handle(&handle));
NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream)); NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream));
// extract the stream order alloc env // extract the stream order alloc env
...@@ -1435,6 +1441,7 @@ void rocblas_gemm(const Tensor *inputA, ...@@ -1435,6 +1441,7 @@ void rocblas_gemm(const Tensor *inputA,
NVTE_CHECK((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r) || NVTE_CHECK((A_type==rocblas_datatype_f16_r && B_type==rocblas_datatype_f16_r && D_type==rocblas_datatype_f16_r) ||
(A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r) || (A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_bf16_r) ||
(A_type==rocblas_datatype_bf16_r && B_type==rocblas_datatype_bf16_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f32_r && B_type==rocblas_datatype_f32_r && D_type==rocblas_datatype_f32_r) || (A_type==rocblas_datatype_f32_r && B_type==rocblas_datatype_f32_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f32_r) || (A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f32_r) ||
(A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f16_r) || (A_type==rocblas_datatype_f8_r && B_type==rocblas_datatype_f8_r && D_type==rocblas_datatype_f16_r) ||
...@@ -1524,8 +1531,6 @@ void rocblas_gemm(const Tensor *inputA, ...@@ -1524,8 +1531,6 @@ void rocblas_gemm(const Tensor *inputA,
computeType, rocblas_gemm_algo::rocblas_gemm_algo_standard,0,flags)); computeType, rocblas_gemm_algo::rocblas_gemm_algo_standard,0,flags));
} }
NVTE_CHECK_ROCBLAS(rocblas_destroy_handle(handle));
int batch_size, input_dim, output_dim; int batch_size, input_dim, output_dim;
if (bias && gelu) { if (bias && gelu) {
if (grad) { if (grad) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment