/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_
#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_

#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#include <cudnn_graph.h>
#endif

#include "transformer_engine/transformer_engine.h"
#include "util/handle_manager.h"

namespace transformer_engine {
yuguo's avatar
yuguo committed
21
#ifndef __HIP_PLATFORM_AMD__
22
namespace detail {
23

24
void CreateCuDNNHandle(cudnnHandle_t* handle);
25

26
}  // namespace detail
27

28
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
29

30
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
31

32
using cudnnExecutionPlanManager = detail::HandleManager<cudnnHandle_t, detail::CreateCuDNNHandle>;
yuguo's avatar
yuguo committed
33
#endif
34
35
}  // namespace transformer_engine

#endif  //  TRANSFORMER_ENGINE_CUDNN_UTILS_H_