/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <filesystem>

#include "../common.h"
#include "../util/cuda_runtime.h"

namespace transformer_engine {

namespace cuda_driver {

16
#ifndef USE_ROCM
17
18
19
20
21
typedef cudaError_t (*VersionedGetEntryPoint)(const char *, void **, unsigned int,
                                              unsigned long long,  // NOLINT(*)
                                              cudaDriverEntryPointQueryResult *);
typedef cudaError_t (*GetEntryPoint)(const char *, void **, unsigned long long,  // NOLINT(*)
                                     cudaDriverEntryPointQueryResult *);
22
#endif
23
24

void *get_symbol(const char *symbol, int cuda_version) {
25
#ifndef USE_ROCM
26
27
28
29
30
31
32
33
  constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint";
  constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion";
  // We link to the libcudart.so already, so can search for it in the current context
  static GetEntryPoint driver_entrypoint_fun =
      reinterpret_cast<GetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint));
  static VersionedGetEntryPoint driver_entrypoint_versioned_fun =
      reinterpret_cast<VersionedGetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned));

34
  cudaDriverEntryPointQueryResult driver_result;
35
#endif
36
  void *entry_point = nullptr;
yuguo's avatar
yuguo committed
37
38
39
40
41
42
#ifdef USE_ROCM
  hipDriverProcAddressQueryResult driver_result;
  NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point, HIP_VERSION_MAJOR*100+HIP_VERSION_MINOR, 0, &driver_result));
  NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS,
             "Could not find CUDA driver entry point for ", symbol);
#else
43
44
45
46
47
48
49
50
51
  if (driver_entrypoint_versioned_fun != nullptr) {
    // Found versioned entrypoint function
    NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version,
                                                    cudaEnableDefault, &driver_result));
  } else {
    NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop.");
    // Versioned entrypoint function not found
    NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result));
  }
52
53
  NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
             "Could not find CUDA driver entry point for ", symbol);
yuguo's avatar
yuguo committed
54
#endif
55
56
  return entry_point;
}
Tim Moon's avatar
Tim Moon committed
57

58
void ensure_context_exists() {
59
60
61
62
63
64
65
66
67
68
69
70
71
  static thread_local bool need_check = []() {
    CUcontext context;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
    if (context == nullptr) {
      // Add primary context to context stack
      CUdevice device;
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device());
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
      NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
    }
    return false;
  }();
72
73
}

Tim Moon's avatar
Tim Moon committed
74
75
76
}  // namespace cuda_driver

}  // namespace transformer_engine