cuda_nvml.h 2.68 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
7
8
9
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_

yuguo's avatar
yuguo committed
10
#ifndef __HIP_PLATFORM_AMD__
11
#include <nvml.h>
yuguo's avatar
yuguo committed
12
#endif
13
14
15
16
17
18
19
20
21

#include <string>

#include "../common.h"
#include "../util/string.h"

namespace transformer_engine {

namespace cuda_nvml {
yuguo's avatar
yuguo committed
22
#ifndef __HIP_PLATFORM_AMD__
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

/*! \brief Get pointer corresponding to symbol in CUDA NVML library
 *
 * Looks up an NVML entry point by name at run time. It is used by the
 * `call` and `get_nvml_error_string` helpers in this header, because the
 * NVML library may differ between compile time and run time. The
 * definition is not in this header.
 *
 * \param[in] symbol Null-terminated name of the NVML entry point
 *                   (e.g. "nvmlErrorString").
 * \return Untyped address of the entry point. Callers reinterpret_cast it
 *         to the matching function-pointer type before calling it.
 *
 * NOTE(review): `call` and `get_nvml_error_string` call through the returned
 * pointer without a null check. They therefore rely on this function never
 * handing back nullptr for a missing symbol (presumably it reports an error
 * instead). Confirm against the definition.
 */
void *get_symbol(const char *symbol);

/*! \brief Call function in CUDA NVML library
 *
 * The CUDA NVML library (libnvidia-ml.so.1 on Linux) may be different at
 * compile-time and run-time. The entry point is therefore resolved by name
 * with get_symbol() on every call.
 *
 * The function-pointer type is built from the deduced argument types, so the
 * arguments must match the NVML prototype exactly. For example, a literal `0`
 * passed for a pointer parameter deduces `int`, not a pointer type. Calling
 * through such a mismatched function type is undefined behaviour.
 *
 * \param[in] symbol Function name
 * \param[in] args   Function arguments (passed by value)
 * \return Status code returned by the NVML function
 */
template <typename... ArgTs>
inline nvmlReturn_t call(const char *symbol, ArgTs... args) {
  // Signature assembled from the caller's argument types; see caveat above.
  using EntryPoint = nvmlReturn_t (*)(ArgTs...);
  const EntryPoint entry = reinterpret_cast<EntryPoint>(get_symbol(symbol));
  return entry(args...);
}

/*! \brief Get NVML error string
 *
 * Resolves NVML's `nvmlErrorString` through get_symbol() and uses it to
 * describe the given return code.
 *
 * \param[in] rc NVML return code
 * \return Description string produced by `nvmlErrorString`
 */
inline const char *get_nvml_error_string(nvmlReturn_t rc) {
  using ErrorStringFn = const char *(*)(nvmlReturn_t);
  const auto error_string = reinterpret_cast<ErrorStringFn>(get_symbol("nvmlErrorString"));
  return error_string(rc);
}
yuguo's avatar
yuguo committed
51
#endif
52
53
54
55
56

}  // namespace cuda_nvml

}  // namespace transformer_engine

yuguo's avatar
yuguo committed
57
#ifndef __HIP_PLATFORM_AMD__
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Check an NVML status.
//
// Evaluates `expr` (an nvmlReturn_t-producing expression) exactly once. When
// the result is NVML_SUCCESS, nothing happens. Otherwise the status is
// translated with get_nvml_error_string() and reported through NVTE_ERROR.
// The do/while(false) wrapper makes the macro a single statement, so it is
// safe inside an unbraced if/else. The `break` leaves only that wrapper.
// Locals carry an nvte_nvml_ prefix so they cannot shadow names used in `expr`.
#define NVTE_CHECK_CUDA_NVML(expr)                                                    \
  do {                                                                                \
    const nvmlReturn_t nvte_nvml_check_rc_ = (expr);                                  \
    if (nvte_nvml_check_rc_ == NVML_SUCCESS) break;                                   \
    const char *nvte_nvml_check_msg_ =                                                \
        ::transformer_engine::cuda_nvml::get_nvml_error_string(nvte_nvml_check_rc_);  \
    NVTE_ERROR("NVML Error: ", nvte_nvml_check_msg_);                                 \
  } while (false)

// Kept unchanged for backward compatibility, since other code may expand
// VA_ARGS. Avoid it in new code: the unprefixed name can clash with other
// headers. Also, GCC keeps the comma for a variadic-only macro such as this
// one in strict ISO modes (e.g. -std=c++17), so VA_ARGS() then yields a
// stray ",".
#define VA_ARGS(...) , ##__VA_ARGS__

// Look up NVML entry point `symbol` by name and call it with the remaining
// arguments, if any. Then check the returned status with
// NVTE_CHECK_CUDA_NVML. `symbol` is a bare identifier, for example:
//   NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2);
//   NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetCount_v2, &count);
// The comma is pasted directly against __VA_ARGS__, after the named `symbol`
// parameter (GNU `, ##__VA_ARGS__` extension). It is therefore dropped when
// no arguments follow the symbol, in GNU and strict ISO modes alike.
// Routing through the variadic-only VA_ARGS helper would keep the comma in
// strict modes and break zero-argument calls.
#define NVTE_CALL_CHECK_CUDA_NVML(symbol, ...)                                            \
  do {                                                                                   \
    NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol, ##__VA_ARGS__)); \
  } while (false)
yuguo's avatar
yuguo committed
73
#endif
74
75

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_