Unverified Commit 7ddc5932 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Better cuBLAS handle management (#1389)



* Do not create multiple cublas handles
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix for multiple GPUs per thread
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix multithreaded execution
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix from conflict
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4a74ef8c
......@@ -57,9 +57,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t)
}
}
void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}
// Eagerly initializes the calling thread's cuDNN handle for the current CUDA
// device by touching the HandleManager singleton; the handle itself is discarded.
void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); }
namespace detail {
// Handle factory plugged into HandleManager: creates a cuDNN handle in place,
// aborting via NVTE_CHECK_CUDNN if cudnnCreate fails.
void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); }
}  // namespace detail
} // namespace transformer_engine
......@@ -68,6 +72,6 @@ namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
void* cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend
......@@ -10,37 +10,25 @@
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cstdint>
#include <mutex>
#include <cudnn_graph.h>
#include "transformer_engine/transformer_engine.h"
#include "util/handle_manager.h"
namespace transformer_engine {
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
namespace detail {
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
void CreateCuDNNHandle(cudnnHandle_t* handle);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
} // namespace detail
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
~cudnnExecutionPlanManager() {}
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
private:
cudnnHandle_t handle_ = nullptr;
};
using cudnnExecutionPlanManager = detail::HandleManager<cudnnHandle_t, detail::CreateCuDNNHandle>;
} // namespace transformer_engine
#endif
#endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_
......@@ -329,7 +329,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
t = input_QKV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
......@@ -411,7 +411,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
t = input_QKV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
......@@ -511,7 +511,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
t_kv = input_KV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
......@@ -602,7 +602,7 @@ void nvte_fused_attn_bwd_kvpacked(
t_kv = input_KV->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
......@@ -699,7 +699,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
t_kv = input_K->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
......@@ -786,7 +786,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
t_kv = input_K->data.shape[0];
}
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
......
......@@ -14,6 +14,7 @@
#include <mutex>
#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "common/util/cuda_runtime.h"
......@@ -47,6 +48,10 @@ uint32_t _getAlignment(uintptr_t address) {
}
}
// Handle factory plugged into HandleManager: creates a cuBLASLt handle in
// place, aborting via NVTE_CHECK_CUBLAS if cublasLtCreate fails.
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}
struct GemmParam {
void *A;
void *B;
......@@ -140,6 +145,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
namespace transformer_engine {
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda,
int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad,
......@@ -192,8 +199,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
float zero = 0.0;
float beta = (accumulate) ? one : zero;
cublasLtHandle_t handle;
NVTE_CHECK_CUBLAS(cublasLtCreate(&handle));
cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
cublasLtMatmulDesc_t operationDesc = nullptr;
cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
......
......@@ -211,7 +211,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
_handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
_handle = cudnnExecutionPlanManager::Instance().GetHandle();
_graph.set_io_data_type(get_cudnn_fe_dtype(itype))
.set_intermediate_data_type(get_cudnn_fe_dtype(ctype))
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
#include <vector>
#include "cuda_runtime.h"
#include "logging.h"
namespace transformer_engine::detail {
/*! \brief Thread-local, per-device manager for opaque library handles
 *         (cuDNN, cuBLASLt, ...).
 *
 *  \tparam Handle  Opaque handle type; must be constructible/comparable with
 *                  nullptr (true for the CUDA library handle typedefs).
 *  \tparam Create  Function that initializes a handle in place.
 *  \tparam Destroy Optional function that releases a handle. When left as
 *                  nullptr (the default) handles are deliberately never
 *                  destroyed at thread exit.
 *
 *  Instance() returns a distinct singleton per thread, and within a thread a
 *  handle is created lazily the first time GetHandle() is called while a
 *  given CUDA device is current.
 */
template <typename Handle, void Create(Handle*), void Destroy(Handle) = nullptr>
class HandleManager {
 public:
  /*! \brief Returns the calling thread's singleton manager. */
  static HandleManager& Instance() {
    static thread_local HandleManager instance;
    return instance;
  }

  /*! \brief Returns the handle for the current CUDA device, creating it on
   *         first use in this thread.
   */
  Handle GetHandle() {
    // Per-thread lazy-init flags; sized once, matching handles_ (device count
    // is fixed for the process lifetime).
    static thread_local std::vector<bool> initialized(handles_.size(), false);
    const int device_id = cuda::current_device();
    NVTE_CHECK(0 <= device_id && static_cast<std::size_t>(device_id) < handles_.size(),
               "invalid CUDA device ID");
    if (!initialized[device_id]) {
      Create(&(handles_[device_id]));
      initialized[device_id] = true;
    }
    return handles_[device_id];
  }

  ~HandleManager() {
    if (Destroy != nullptr) {
      for (auto& handle : handles_) {
        // Skip devices whose handle was never created in this thread; the
        // slots are null-initialized by the constructor.
        if (handle != nullptr) {
          Destroy(handle);
        }
      }
    }
  }

 private:
  // One slot per visible CUDA device, null-filled; slots are populated
  // lazily by GetHandle().
  HandleManager() : handles_(cuda::num_devices(), nullptr) {}

  // NOTE: no default member initializer here — the previous `= nullptr`
  // initializer was ill-formed (std::vector is not constructible from
  // nullptr) and only escaped diagnosis because, in a template, a default
  // member initializer overridden by every constructor is never instantiated.
  std::vector<Handle> handles_;
};
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment