cudnn_utils.cpp 2.23 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
9
#include "cudnn_utils.h"

#include "./util/logging.h"
10
11
12
13
#include "transformer_engine/cudnn.h"

namespace transformer_engine {

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Map a Transformer Engine DType onto the matching cuDNN backend enumerator
// (cudnnDataType_t). Supported: Int32, Int64, Float16, Float32, BFloat16,
// Float8E4M3 and Float8E5M2. Any other DType is rejected through NVTE_ERROR.
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
  using DT = transformer_engine::DType;
  if (t == DT::kInt32) return CUDNN_DATA_INT32;
  if (t == DT::kInt64) return CUDNN_DATA_INT64;
  if (t == DT::kFloat16) return CUDNN_DATA_HALF;
  if (t == DT::kFloat32) return CUDNN_DATA_FLOAT;
  if (t == DT::kBFloat16) return CUDNN_DATA_BFLOAT16;
  if (t == DT::kFloat8E4M3) return CUDNN_DATA_FP8_E4M3;
  if (t == DT::kFloat8E5M2) return CUDNN_DATA_FP8_E5M2;
  // No mapping for this DType.
  NVTE_ERROR("Invalid cuDNN data type. \n");
}

// Map a Transformer Engine DType onto the matching cuDNN *frontend* value
// (cudnn_frontend::DataType_t). The supported set mirrors get_cudnn_dtype:
// Int32, Int64, Float16, Float32, BFloat16, Float8E4M3 and Float8E5M2.
// Any other DType is rejected through NVTE_ERROR.
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) {
  using TeType = transformer_engine::DType;
  using FeType = cudnn_frontend::DataType_t;
  switch (t) {
    // Integer types.
    case TeType::kInt32: return FeType::INT32;
    case TeType::kInt64: return FeType::INT64;
    // 16/32-bit floating-point types.
    case TeType::kFloat16: return FeType::HALF;
    case TeType::kFloat32: return FeType::FLOAT;
    case TeType::kBFloat16: return FeType::BFLOAT16;
    // 8-bit floating-point types.
    case TeType::kFloat8E4M3: return FeType::FP8_E4M3;
    case TeType::kFloat8E5M2: return FeType::FP8_E5M2;
    // No mapping for this DType.
    default: NVTE_ERROR("Invalid cuDNN data type. \n");
  }
}

60
61
62
63
64
65
66
// Calls GetHandle() on cudnnExecutionPlanManager::Instance() and discards the
// result. NOTE(review): presumably done only for the side effect of making the
// manager create its cuDNN handle up front — confirm against GetHandle().
void nvte_cudnn_handle_init() {
  const auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
  static_cast<void>(handle);  // the value itself is not needed here
}

namespace detail {

// Create a new cuDNN handle with cudnnCreate and store it through `handle`.
// The returned status is checked by NVTE_CHECK_CUDNN. NOTE(review): presumably
// that macro raises an error on a non-success status — confirm in its definition.
void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); }

}  // namespace detail
67
68

}  // namespace transformer_engine
69
70
71
72
73
74

namespace cudnn_frontend {

// Defines the symbol `cudnn_dlhandle`, which is required when the cuDNN
// frontend is built with the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING to
// enable dynamic loading.
75
void* cudnn_dlhandle{nullptr};  // definition of the dynamic-loading handle; starts out null
76
77

}  // namespace cudnn_frontend