Unverified Commit 7ddc5932 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Better cuBLAS handle management (#1389)



* Do not create multiple cuBLAS handles
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix for multiple GPUs per thread
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix multithreaded execution
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix from conflict
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4a74ef8c
...@@ -57,9 +57,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) ...@@ -57,9 +57,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
} }
} }
void nvte_cudnn_handle_init() { void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); }
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
} namespace detail {
// Creates a fresh cuDNN handle, aborting via NVTE_CHECK_CUDNN if cudnnCreate fails.
void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); }
} // namespace detail
} // namespace transformer_engine } // namespace transformer_engine
...@@ -68,6 +72,6 @@ namespace cudnn_frontend { ...@@ -68,6 +72,6 @@ namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle` // This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading. // to enable dynamic loading.
void *cudnn_dlhandle = nullptr; void* cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend } // namespace cudnn_frontend
...@@ -10,37 +10,25 @@ ...@@ -10,37 +10,25 @@
#include <cudnn.h> #include <cudnn.h>
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h> #include <cudnn_frontend_utils.h>
#include <cudnn_graph.h>
#include <cstdint>
#include <mutex>
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
#include "util/handle_manager.h"
namespace transformer_engine { namespace transformer_engine {
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); namespace detail {
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); void CreateCuDNNHandle(cudnnHandle_t* handle);
class cudnnExecutionPlanManager { } // namespace detail
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
cudnnHandle_t GetCudnnHandle() { cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
~cudnnExecutionPlanManager() {} cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
private: using cudnnExecutionPlanManager = detail::HandleManager<cudnnHandle_t, detail::CreateCuDNNHandle>;
cudnnHandle_t handle_ = nullptr;
};
} // namespace transformer_engine } // namespace transformer_engine
#endif #endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_
...@@ -329,7 +329,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -329,7 +329,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
t = input_QKV->data.shape[0]; t = input_QKV->data.shape[0];
} }
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype); const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
...@@ -411,7 +411,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con ...@@ -411,7 +411,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
t = input_QKV->data.shape[0]; t = input_QKV->data.shape[0];
} }
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype); const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
...@@ -511,7 +511,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const ...@@ -511,7 +511,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
t_kv = input_KV->data.shape[0]; t_kv = input_KV->data.shape[0];
} }
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
...@@ -602,7 +602,7 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -602,7 +602,7 @@ void nvte_fused_attn_bwd_kvpacked(
t_kv = input_KV->data.shape[0]; t_kv = input_KV->data.shape[0];
} }
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
...@@ -699,7 +699,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -699,7 +699,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
t_kv = input_K->data.shape[0]; t_kv = input_K->data.shape[0];
} }
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
...@@ -786,7 +786,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -786,7 +786,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
t_kv = input_K->data.shape[0]; t_kv = input_K->data.shape[0];
} }
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <mutex> #include <mutex>
#include "../common.h" #include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h" #include "../util/logging.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
...@@ -47,6 +48,10 @@ uint32_t _getAlignment(uintptr_t address) { ...@@ -47,6 +48,10 @@ uint32_t _getAlignment(uintptr_t address) {
} }
} }
// Creates a fresh cuBLASLt handle, aborting via NVTE_CHECK_CUBLAS on failure.
// NOTE(review): presumably supplied as the Create callback of a HandleManager — confirm at the use site.
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}
struct GemmParam { struct GemmParam {
void *A; void *A;
void *B; void *B;
...@@ -140,6 +145,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -140,6 +145,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
namespace transformer_engine { namespace transformer_engine {
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
...@@ -192,8 +199,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -192,8 +199,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
float zero = 0.0; float zero = 0.0;
float beta = (accumulate) ? one : zero; float beta = (accumulate) ? one : zero;
cublasLtHandle_t handle; cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
NVTE_CHECK_CUBLAS(cublasLtCreate(&handle));
cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatmulDesc_t operationDesc = nullptr;
cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
......
...@@ -211,7 +211,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor ...@@ -211,7 +211,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;); wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
_handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); _handle = cudnnExecutionPlanManager::Instance().GetHandle();
_graph.set_io_data_type(get_cudnn_fe_dtype(itype)) _graph.set_io_data_type(get_cudnn_fe_dtype(itype))
.set_intermediate_data_type(get_cudnn_fe_dtype(ctype)) .set_intermediate_data_type(get_cudnn_fe_dtype(ctype))
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
#include <vector>
#include "cuda_runtime.h"
#include "logging.h"
namespace transformer_engine::detail {

/// Lazily-created, per-thread, per-CUDA-device cache of library handles
/// (e.g. cublasLtHandle_t, cudnnHandle_t).
///
/// @tparam Handle  the handle type; must be constructible/comparable with
///                 nullptr (it is used as the "not yet created" sentinel).
/// @tparam Create  called once per (thread, device) to create the handle.
/// @tparam Destroy optional teardown function; when left at its nullptr
///                 default the destructor does nothing.
template <typename Handle, void Create(Handle*), void Destroy(Handle) = nullptr>
class HandleManager {
 public:
  /// One manager per thread, so handles are never shared across threads.
  static HandleManager& Instance() {
    static thread_local HandleManager instance;
    return instance;
  }

  /// Returns the handle for the current CUDA device, creating it on first use.
  Handle GetHandle() {
    // Per-thread flags tracking which device slots have been initialized.
    // handles_ never resizes, so sizing this once at first call is safe.
    static thread_local std::vector<bool> initialized(handles_.size(), false);
    const int device_id = cuda::current_device();
    // Cast avoids a signed/unsigned comparison; negativity is checked first.
    NVTE_CHECK(0 <= device_id && static_cast<size_t>(device_id) < handles_.size(),
               "invalid CUDA device ID");
    if (!initialized[device_id]) {
      Create(&(handles_[device_id]));
      initialized[device_id] = true;
    }
    return handles_[device_id];
  }

  ~HandleManager() {
    // Destroy is a compile-time parameter: prune the loop entirely when no
    // teardown function was supplied.
    if constexpr (Destroy != nullptr) {
      for (auto& handle : handles_) {
        // Skip slots never touched by GetHandle (still the nullptr
        // placeholder from the constructor) — don't destroy a non-handle.
        if (handle != nullptr) {
          Destroy(handle);
        }
      }
    }
  }

 private:
  /// One (initially null) slot per visible CUDA device.
  HandleManager() : handles_(cuda::num_devices(), nullptr) {}

  std::vector<Handle> handles_;  // fixed size; indexed by CUDA device ID
};

}  // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment