Unverified Commit 2643ba1d authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[C] Separating cudnn common utils from fused_attn (#1314)



* split cudnn utils from fused_attn/util
---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent e5ffaa76
...@@ -46,7 +46,6 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) ...@@ -46,7 +46,6 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR}/..) include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES) set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES list(APPEND transformer_engine_SOURCES
pycudnn.cpp
cudnn_utils.cpp cudnn_utils.cpp
transformer_engine.cpp transformer_engine.cpp
common.cu common.cu
......
...@@ -4,13 +4,70 @@ ...@@ -4,13 +4,70 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../fused_attn/utils.h" #include "cudnn_utils.h"
#include "./util/logging.h"
#include "transformer_engine/cudnn.h" #include "transformer_engine/cudnn.h"
namespace transformer_engine { namespace transformer_engine {
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return CUDNN_DATA_INT32;
case DType::kInt64:
return CUDNN_DATA_INT64;
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
// get cuDNN data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
void nvte_cudnn_handle_init() { void nvte_cudnn_handle_init() {
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
} }
} // namespace transformer_engine } // namespace transformer_engine
namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cstdint>
#include <mutex>
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager {
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
~cudnnExecutionPlanManager() {}
private:
cudnnHandle_t handle_ = nullptr;
};
} // namespace transformer_engine
#endif
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
#include "../common.h" #include "../common.h"
#include "../cudnn_utils.h"
#include "../util/cuda_runtime.h" #include "../util/cuda_runtime.h"
#include "../util/system.h" #include "../util/system.h"
#include "fused_attn_f16_arbitrary_seqlen.h" #include "fused_attn_f16_arbitrary_seqlen.h"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <vector> #include <vector>
#include "../common.h" #include "../common.h"
#include "../cudnn_utils.h"
#include "../util/cuda_runtime.h" #include "../util/cuda_runtime.h"
#include "../util/system.h" #include "../util/system.h"
#include "fused_attn_f16_arbitrary_seqlen.h" #include "fused_attn_f16_arbitrary_seqlen.h"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <vector> #include <vector>
#include "../common.h" #include "../common.h"
#include "../cudnn_utils.h"
#include "fused_attn_f16_max512_seqlen.h" #include "fused_attn_f16_max512_seqlen.h"
#include "utils.h" #include "utils.h"
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
************************************************************************/ ************************************************************************/
#include "../common.h" #include "../common.h"
#include "../cudnn_utils.h"
#include "../util/system.h" #include "../util/system.h"
#include "fused_attn_fp8.h" #include "fused_attn_fp8.h"
#include "utils.h" #include "utils.h"
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <cmath> #include <cmath>
#include "../common.h" #include "../common.h"
#include "../cudnn_utils.h"
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
#include "utils.h" #include "utils.h"
...@@ -495,50 +496,4 @@ size_t get_max_tokens(size_t num_tokens) { ...@@ -495,50 +496,4 @@ size_t get_max_tokens(size_t num_tokens) {
} }
} // namespace fused_attn } // namespace fused_attn
// get cuDNN data type
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return CUDNN_DATA_INT32;
case DType::kInt64:
return CUDNN_DATA_INT64;
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
return CUDNN_DATA_FLOAT;
case DType::kBFloat16:
return CUDNN_DATA_BFLOAT16;
case DType::kFloat8E4M3:
return CUDNN_DATA_FP8_E4M3;
case DType::kFloat8E5M2:
return CUDNN_DATA_FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
// get cuDNN data type
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return cudnn_frontend::DataType_t::INT32;
case DType::kInt64:
return cudnn_frontend::DataType_t::INT64;
case DType::kFloat16:
return cudnn_frontend::DataType_t::HALF;
case DType::kFloat32:
return cudnn_frontend::DataType_t::FLOAT;
case DType::kBFloat16:
return cudnn_frontend::DataType_t::BFLOAT16;
case DType::kFloat8E4M3:
return cudnn_frontend::DataType_t::FP8_E4M3;
case DType::kFloat8E5M2:
return cudnn_frontend::DataType_t::FP8_E5M2;
default:
NVTE_ERROR("Invalid cuDNN data type. \n");
}
}
} // namespace transformer_engine } // namespace transformer_engine
...@@ -140,29 +140,8 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at ...@@ -140,29 +140,8 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
size_t get_max_batch_size(size_t batch_size); size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens); size_t get_max_tokens(size_t num_tokens);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
class cudnnExecutionPlanManager { } // namespace fused_attn
public:
static cudnnExecutionPlanManager &Instance() {
static thread_local cudnnExecutionPlanManager instance;
return instance;
}
cudnnHandle_t GetCudnnHandle() {
static thread_local std::once_flag flag;
std::call_once(flag, [&] { cudnnCreate(&handle_); });
return handle_;
}
~cudnnExecutionPlanManager() {}
private:
cudnnHandle_t handle_ = nullptr;
};
} // namespace transformer_engine } // namespace transformer_engine
#endif #endif
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
namespace cudnn_frontend {
// This is needed to define the symbol `cudnn_dlhandle`
// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING
// to enable dynamic loading.
void *cudnn_dlhandle = nullptr;
} // namespace cudnn_frontend
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment