Commit f4bd89eb authored by wenjh's avatar wenjh
Browse files

Fix hipblaslt handle management

parent a13c52ad
......@@ -465,106 +465,6 @@ transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t)
namespace {
static class HandlePool {
public:
hipblasLtHandle_t get(int device_id) {
std::lock_guard<std::mutex> lock(mt);
if (pool.empty()) {
int device_count = 0;
NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
pool.resize(device_count);
return nullptr;
}
if (!pool[device_id].empty()) {
hipblasLtHandle_t h = pool[device_id].front();
pool[device_id].pop_front();
return h;
}
return nullptr;
}
hipblasLtHandle_t obtain(int device_id) {
hipblasLtHandle_t h = get(device_id);
if (h == nullptr) {
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
}
return h;
}
void store(const std::vector<hipblasLtHandle_t>& handles) {
std::lock_guard<std::mutex> lock(mt);
if (pool.empty()) {
std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
}
for (unsigned int i = 0; i < pool.size(); i++) {
if (handles[i] != nullptr) {
pool[i].push_front(handles[i]);
}
}
}
~HandlePool() {
#if DESTROY_HIPBLASLT_HANDLES_POOL
std::lock_guard<std::mutex> lock(mt);
for (auto& hlist : pool) {
for (auto& h : hlist) {
hipblasLtDestroy(h);
}
}
pool.clear();
#endif
}
inline size_t get_size() const { return pool.size(); }
private:
std::mutex mt;
using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
// Order of destructors between thread_local and global is not actually guaranteed
// As a simple w/a make pool storage "leaky"
// Just do not destruct it and do not destroy hipbladLt handles
// Let OS deal with it on application exit
#if DESTROY_HIPBLASLT_HANDLES_POOL
Pool pool;
#else
Pool& pool = *new Pool();
#endif
} handle_pool;
thread_local static class HandleCache {
public:
hipblasLtHandle_t get(int device_id) const { return d.empty() ? nullptr : d[device_id]; }
hipblasLtHandle_t obtain(int device_id) {
hipblasLtHandle_t h = get(device_id);
if (h) {
return h;
}
h = handle_pool.obtain(device_id);
set(device_id, h);
return h;
}
void set(int device_id, hipblasLtHandle_t h) {
if (d.empty()) {
d.resize(handle_pool.get_size());
}
d[device_id] = h;
}
~HandleCache() {
if (!d.empty()) {
handle_pool.store(d);
}
}
private:
std::vector<hipblasLtHandle_t> d;
} cached_handles;
class csv_helper {
public:
struct start {};
......@@ -987,18 +887,12 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
} //namespace
/* Warning: only call once per device!
* When calling nvte_multi_stream_cublas_gemm with hipblaslt backend
* need to create multiple handles corresponding to compute_streams
* to avoid a handle be used by multi-streams concurrently.
*/
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
NVTE_CHECK(hipblaslt_handles != nullptr);
for (int i = 0; i < compute_num_streams; i++) {
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
}
// Creation callback plugged into detail::HandleManager (see the
// hipBlasLtHandleManager alias below): initializes one hipBLASLt handle
// in place, aborting via NVTE_CHECK_HIPBLASLT on failure.
static inline void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
}
using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle>;
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
using namespace transformer_engine;
switch (t) {
......@@ -1018,8 +912,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
bool grad, void* workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, const Tensor* inputCounter, hipStream_t stream,
hipblasLtHandle_t handle) {
bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
void* A = inputA->data.dptr;
void* A_scale_inverse = inputA->scale_inv.dptr;
float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
......@@ -1064,12 +957,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
if (handle == nullptr) {
handle = cached_handles.get(device_id);
if (handle == nullptr) {
handle = cached_handles.obtain(device_id);
}
}
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
hipblasLtMatmulDesc_t operationDesc = nullptr;
hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
......@@ -1403,15 +1291,7 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
}
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
......@@ -1929,20 +1809,10 @@ void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
}
hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
m_split, n_split, gemm_producer, inputCounter, stream, handle);
m_split, n_split, gemm_producer, inputCounter, stream);
return;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment