/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_

#include <cuda.h>

12
13
#include <string>

Tim Moon's avatar
Tim Moon committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include "../common.h"
#include "../util/string.h"

namespace transformer_engine {

namespace cuda_driver {

/*! \brief Get pointer corresponding to symbol in CUDA driver library
 *
 * The returned pointer is untyped; callers are responsible for
 * casting it to the correct function type before use.
 *
 * NOTE(review): the definition is not visible from this header, so
 * the behavior when the symbol cannot be resolved (null return vs.
 * error) should be confirmed against the implementation.
 *
 * \param[in] symbol Name of the symbol to look up
 * \return Untyped pointer to the symbol
 */
void *get_symbol(const char *symbol);

/*! \brief Invoke a CUDA driver library function by name
 *
 * The driver library seen at compile-time is not necessarily the one
 * present at run-time (libcuda.so.1 on Linux). The CUDA Toolkit, for
 * instance, ships stub driver libraries so that code can be built on
 * machines without GPUs. Resolving the entry point through
 * get_symbol and calling it indirectly makes sure the function that
 * runs comes from the lazily-initialized run-time library.
 *
 * The function type used for the call is assembled from the deduced
 * (by-value) argument types, so each argument must already have
 * exactly the type of the corresponding driver function parameter.
 *
 * \param[in] symbol Function name
 * \param[in] args   Function arguments
 * \return Status code returned by the driver function
 */
template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) {
  // Pointer-to-function type matching the deduced argument list
  typedef CUresult (*DriverFuncPtr)(ArgTs...);
  const DriverFuncPtr entry_point = reinterpret_cast<DriverFuncPtr>(get_symbol(symbol));
  return entry_point(args...);
}

}  // namespace cuda_driver

}  // namespace transformer_engine

46
47
48
49
50
51
52
53
54
/*! \brief Check the status of a CUDA driver API call
 *
 * Evaluates \p expr exactly once. If the result is not CUDA_SUCCESS,
 * the error description is looked up with cuGetErrorString (called
 * through cuda_driver::call) and reported with NVTE_ERROR.
 *
 * cuGetErrorString returns CUDA_ERROR_INVALID_VALUE and sets the
 * description to NULL when it does not recognize the error code. The
 * description pointer is therefore initialized and replaced with a
 * placeholder string if the lookup fails, so that a null or
 * uninitialized pointer is never forwarded to NVTE_ERROR.
 *
 * \param[in] expr Expression that evaluates to a CUresult
 */
#define NVTE_CHECK_CUDA_DRIVER(expr)                                                             \
  do {                                                                                           \
    const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr);                                       \
    if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) {                                         \
      const char *desc_NVTE_CHECK_CUDA_DRIVER = nullptr;                                         \
      const CUresult lookup_NVTE_CHECK_CUDA_DRIVER = ::transformer_engine::cuda_driver::call(    \
          "cuGetErrorString", status_NVTE_CHECK_CUDA_DRIVER, &desc_NVTE_CHECK_CUDA_DRIVER);      \
      if (lookup_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS ||                                       \
          desc_NVTE_CHECK_CUDA_DRIVER == nullptr) {                                              \
        desc_NVTE_CHECK_CUDA_DRIVER = "unknown error";                                           \
      }                                                                                          \
      NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);                                   \
    }                                                                                            \
  } while (false)

57
58
59
/*! \brief Call a CUDA driver function by name and check its status
 *
 * \p symbol is stringified, so pass the bare function name (e.g.
 * cuGetErrorString) rather than a string literal. The remaining
 * arguments are forwarded to cuda_driver::call and the resulting
 * status is checked with NVTE_CHECK_CUDA_DRIVER, which already
 * expands to a single do/while statement.
 *
 * Note: with no driver arguments the expansion leaves a trailing
 * comma after the symbol name, which standard preprocessors reject.
 *
 * \param[in] symbol Name of the driver function
 * \param[in] ...    Arguments passed to the driver function
 */
#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \
  NVTE_CHECK_CUDA_DRIVER(::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__))
Tim Moon's avatar
Tim Moon committed
61
62

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_