/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <dlfcn.h>

#include <filesystem>
#include <utility>

#include "../common.h"
#include "../util/cuda_runtime.h"

namespace transformer_engine {

namespace {

/*! \brief Wrapper class for a shared library
 *
 * \todo Windows support
 */
class Library {
 public:
  explicit Library(const char *filename) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
    NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
    handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL);
    NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed");
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

  ~Library() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
#else
    if (handle_ != nullptr) {
      dlclose(handle_);
    }
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

44
  Library(const Library &) = delete;  // move-only
Tim Moon's avatar
Tim Moon committed
45

46
  Library(Library &&other) noexcept { swap(*this, other); }
Tim Moon's avatar
Tim Moon committed
47

48
  Library &operator=(Library other) noexcept {
Tim Moon's avatar
Tim Moon committed
49
50
51
52
53
    // Copy-and-swap idiom
    swap(*this, other);
    return *this;
  }

54
  friend void swap(Library &first, Library &second) noexcept;
Tim Moon's avatar
Tim Moon committed
55

56
  void *get() noexcept { return handle_; }
Tim Moon's avatar
Tim Moon committed
57

58
  const void *get() const noexcept { return handle_; }
Tim Moon's avatar
Tim Moon committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

  /*! \brief Get pointer corresponding to symbol in shared library */
  void *get_symbol(const char *symbol) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
    NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
    void *ptr = dlsym(handle_, symbol);
    NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library");
    return ptr;
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

 private:
  void *handle_ = nullptr;
};

76
void swap(Library &first, Library &second) noexcept {
Tim Moon's avatar
Tim Moon committed
77
78
79
80
81
  using std::swap;
  swap(first.handle_, second.handle_);
}

/*! \brief Lazily-initialized shared library for CUDA driver */
82
Library &cuda_driver_lib() {
Tim Moon's avatar
Tim Moon committed
83
84
85
86
87
88
89
90
91
92
93
94
95
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
  constexpr char lib_name[] = "nvcuda.dll";
#else
  constexpr char lib_name[] = "libcuda.so.1";
#endif
  static Library lib(lib_name);
  return lib;
}

}  // namespace

namespace cuda_driver {

96
97
98
99
100
101
102
103
void *get_symbol(const char *symbol) {
  void *entry_point;
  cudaDriverEntryPointQueryResult driver_result;
  NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result));
  NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
             "Could not find CUDA driver entry point for ", symbol);
  return entry_point;
}
Tim Moon's avatar
Tim Moon committed
104
105
106
107

}  // namespace cuda_driver

}  // namespace transformer_engine