Unverified Commit 2643ba1d authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[C] Separating cudnn common utils from fused_attn (#1314)



* split cudnn utils from fused_attn/util
---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent e5ffaa76
......@@ -46,7 +46,6 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
pycudnn.cpp
cudnn_utils.cpp
transformer_engine.cpp
common.cu
......
......@@ -4,13 +4,70 @@
* See LICENSE for license information.
************************************************************************/
#include "../fused_attn/utils.h"
#include "cudnn_utils.h"
#include "./util/logging.h"
#include "transformer_engine/cudnn.h"
namespace transformer_engine {
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return CUDNN_DATA_INT32;
case DType::kInt64:
return CUDNN_DATA_INT64;
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
// get cuDNN data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}
} // namespace transformer_engine
namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cstdint>
#include <mutex>
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
~cudnnExecutionPlanManager() {}
private:
cudnnHandle_t handle_ = nullptr;
};
} // namespace transformer_engine
#endif
......@@ -7,6 +7,7 @@
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
......
......@@ -13,6 +13,7 @@
#include <vector>
#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
......
......@@ -12,6 +12,7 @@
#include <vector>
#include "../common.h"
#include "../cudnn_utils.h"
#include "fused_attn_f16_max512_seqlen.h"
#include "utils.h"
......
......@@ -5,6 +5,7 @@
************************************************************************/
#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/system.h"
#include "fused_attn_fp8.h"
#include "utils.h"
......
......@@ -8,6 +8,7 @@
#include <cmath>
#include "../common.h"
#include "../cudnn_utils.h"
#include "transformer_engine/fused_attn.h"
#include "utils.h"
......@@ -495,50 +496,4 @@ size_t get_max_tokens(size_t num_tokens) {
}
} // namespace fused_attn
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return CUDNN_DATA_INT32;
case DType::kInt64:
return CUDNN_DATA_INT64;
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
// get cuDNN data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
} // namespace transformer_engine
......@@ -140,29 +140,8 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
~cudnnExecutionPlanManager() {}
private:
cudnnHandle_t handle_ = nullptr;
};
} // namespace fused_attn
} // namespace transformer_engine
#endif
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment