/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_

#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cudnn_graph.h>

#include "transformer_engine/transformer_engine.h"
#include "util/handle_manager.h"

namespace transformer_engine {

20
namespace detail {
21

22
void CreateCuDNNHandle(cudnnHandle_t* handle);
23

24
}  // namespace detail
25

26
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
27

28
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
29

30
using cudnnExecutionPlanManager = detail::HandleManager<cudnnHandle_t, detail::CreateCuDNNHandle>;
31
32
33

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_CUDNN_UTILS_H_