/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <filesystem>

#include "../common.h"
#include "../util/cuda_runtime.h"

namespace transformer_engine {

namespace cuda_driver {

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
typedef cudaError_t (*VersionedGetEntryPoint)(const char *, void **, unsigned int,
                                              unsigned long long,  // NOLINT(*)
                                              cudaDriverEntryPointQueryResult *);
typedef cudaError_t (*GetEntryPoint)(const char *, void **, unsigned long long,  // NOLINT(*)
                                     cudaDriverEntryPointQueryResult *);

void *get_symbol(const char *symbol, int cuda_version) {
  constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint";
  constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion";
  // We link to the libcudart.so already, so can search for it in the current context
  static GetEntryPoint driver_entrypoint_fun =
      reinterpret_cast<GetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint));
  static VersionedGetEntryPoint driver_entrypoint_versioned_fun =
      reinterpret_cast<VersionedGetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned));

31
  cudaDriverEntryPointQueryResult driver_result;
32
33
34
35
36
37
38
39
40
41
  void *entry_point = nullptr;
  if (driver_entrypoint_versioned_fun != nullptr) {
    // Found versioned entrypoint function
    NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version,
                                                    cudaEnableDefault, &driver_result));
  } else {
    NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop.");
    // Versioned entrypoint function not found
    NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result));
  }
42
43
44
45
  NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
             "Could not find CUDA driver entry point for ", symbol);
  return entry_point;
}
Tim Moon's avatar
Tim Moon committed
46
47
48
49

}  // namespace cuda_driver

}  // namespace transformer_engine