/************************************************************************* * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include "cudnn_utils.h" #include "./util/logging.h" #include "transformer_engine/cudnn.h" namespace transformer_engine { #ifndef USE_ROCM // get cuDNN data type cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { using namespace transformer_engine; switch (t) { case DType::kInt32: return CUDNN_DATA_INT32; case DType::kInt64: return CUDNN_DATA_INT64; case DType::kFloat16: return CUDNN_DATA_HALF; case DType::kFloat32: return CUDNN_DATA_FLOAT; case DType::kBFloat16: return CUDNN_DATA_BFLOAT16; case DType::kFloat8E4M3: return CUDNN_DATA_FP8_E4M3; case DType::kFloat8E5M2: return CUDNN_DATA_FP8_E5M2; default: NVTE_ERROR("Invalid cuDNN data type. \n"); } } // get cuDNN data type cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) { using namespace transformer_engine; switch (t) { case DType::kInt32: return cudnn_frontend::DataType_t::INT32; case DType::kInt64: return cudnn_frontend::DataType_t::INT64; case DType::kFloat16: return cudnn_frontend::DataType_t::HALF; case DType::kFloat32: return cudnn_frontend::DataType_t::FLOAT; case DType::kBFloat16: return cudnn_frontend::DataType_t::BFLOAT16; case DType::kFloat8E4M3: return cudnn_frontend::DataType_t::FP8_E4M3; case DType::kFloat8E5M2: return cudnn_frontend::DataType_t::FP8_E5M2; default: NVTE_ERROR("Invalid cuDNN data type. \n"); } } void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } namespace detail { void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); } } // namespace detail #endif } // namespace transformer_engine namespace cudnn_frontend { // This is needed to define the symbol `cudnn_dlhandle` // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING // to enable dynamic loading. void* cudnn_dlhandle = nullptr; } // namespace cudnn_frontend