/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <filesystem>

#include "../common.h"
#include "../util/cuda_runtime.h"

namespace transformer_engine {

namespace cuda_driver {

16
17
void *get_symbol(const char *symbol) {
  void *entry_point;
yuguo's avatar
yuguo committed
18
19
20
21
22
23
#ifdef USE_ROCM
  hipDriverProcAddressQueryResult driver_result;
  NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point, HIP_VERSION_MAJOR*100+HIP_VERSION_MINOR, 0, &driver_result));
  NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS,
             "Could not find CUDA driver entry point for ", symbol);
#else
24
25
26
27
  cudaDriverEntryPointQueryResult driver_result;
  NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result));
  NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
             "Could not find CUDA driver entry point for ", symbol);
yuguo's avatar
yuguo committed
28
#endif
29
30
  return entry_point;
}
Tim Moon's avatar
Tim Moon committed
31
32
33
34

}  // namespace cuda_driver

}  // namespace transformer_engine