/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"

namespace transformer_engine {

namespace rtc {

/*! \brief Whether NVRTC support is enabled
 *
 * NVRTC support can be disabled by setting NVTE_DISABLE_NVRTC=1 in
 * the environment.
 */
bool is_enabled();

/*! \brief Wrapper class for a runtime-compiled CUDA kernel */
class Kernel {
 public:
  Kernel(std::string mangled_name, std::string compiled_code);
  ~Kernel();
41
42
43
44
  Kernel(const Kernel &) = delete;  // move-only
  Kernel(Kernel &&) noexcept;
  Kernel &operator=(Kernel) noexcept;
  friend void swap(Kernel &first, Kernel &second) noexcept;
Tim Moon's avatar
Tim Moon committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

  /*! \brief Launch CUDA kernel
   *
   * Loads the kernel into the device the first time the device is
   * accessed.
   *
   * \param[in] device_id        CUDA device
   * \param[in] grid_dim         Grid dimensions in blocks
   * \param[in] block_dim        Thread block dimensions
   * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
   *                             bytes
   * \param[in] stream           CUDA stream
   * \param[in] args             Kernel arguments
   */
  template <typename... ArgTs>
60
61
62
63
64
65
  void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
              unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
    void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
    NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
                                grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
                                static_cast<CUstream>(stream), arg_ptrs, nullptr);
Tim Moon's avatar
Tim Moon committed
66
67
68
69
70
71
72
73
74
  }

  /*! \brief CUDA function for given CUDA device
   *
   * Loads the kernel into the device the first time the device is
   * accessed.
   */
  CUfunction get_function(int device_id);

75
76
77
78
79
80
  /*! \brief Sets the preferred cache configuration for a function
   *
   * Wrapper of the CUDA Driver API function "cuFuncSetCacheConfig"
   */
  void set_function_cache_config(int device_id, CUfunc_cache cache_config);

Tim Moon's avatar
Tim Moon committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
 private:
  /*! \brief Mangled function name */
  std::string mangled_name_;
  /*! \brief  Compiled assembly, either in PTX or cubin format */
  std::string compiled_code_;
  /*! CUDA module for each CUDA device */
  std::vector<CUmodule> modules_;
  /*! CUDA function for each CUDA device */
  std::vector<CUfunction> functions_;

  /*! Flags for thread-safe kernel initialization */
  std::unique_ptr<std::vector<std::once_flag>> init_flags_;

  /*! \brief Uninitialized CUDA module */
  static constexpr CUmodule null_module = static_cast<CUmodule>(nullptr);
  /*! Uninitialized CUDA function */
  static constexpr CUfunction null_function = static_cast<CUfunction>(nullptr);
};

/*! \brief Singleton class to manage runtime-compiled CUDA kernels */
class KernelManager {
 public:
  /*! \brief Get singleton instance */
104
  static KernelManager &instance();
Tim Moon's avatar
Tim Moon committed
105
106
107
108
109
110
111
112
113
114
115

  /*! \brief Compile CUDA kernel for current CUDA device
   *
   * The compiled kernel is cached and made available for launching.
   *
   * \param[in] kernel_label Unique identifying string for kernel
   * \param[in] kernel_name  Kernel name within source code
   * \param[in] code         Kernel source code
   * \param[in] filename     Path to associate with source code,
   *                         primarily for debugging
   */
116
117
  void compile(const std::string &kernel_label, const std::string &kernel_name,
               const std::string &code, const std::string &filename);
Tim Moon's avatar
Tim Moon committed
118
119
120
121
122
123
124
125

  /*! \brief Whether CUDA kernel has been compiled for CUDA device
   *
   * \param[in] kernel_label Unique identifying string for kernel
   * \param[in] device_id    CUDA device (default is current device)

   * \return Whether kernel has been compiled
   */
126
  bool is_compiled(const std::string &kernel_label, int device_id = -1) const;
Tim Moon's avatar
Tim Moon committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140

  /*! \brief Launch CUDA kernel on current CUDA device
   *
   * Assumes the kernel has already been compiled.
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] grid_dim         Grid dimensions in blocks
   * \param[in] block_dim        Thread block dimensions
   * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
   *                             bytes
   * \param[in] stream           CUDA stream
   * \param[in] args             Kernel arguments
   */
  template <typename... ArgTs>
141
142
  void launch(const std::string &kernel_label, const dim3 grid_dim, const dim3 block_dim,
              unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
Tim Moon's avatar
Tim Moon committed
143
144
    const int device_id = cuda::current_device();
    const auto key = get_kernel_cache_key(kernel_label, device_id);
145
146
    NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to launch RTC kernel before compilation");
    kernel_cache_.at(key).launch(device_id, grid_dim, block_dim, shared_mem_bytes, stream,
Tim Moon's avatar
Tim Moon committed
147
148
149
                                 std::forward<ArgTs>(args)...);
  }

150
151
152
153
154
155
156
157
158
  /*! \brief Sets the preferred cache configuration for a function in the context
   *
   * Assumes the kernel has already been compiled.
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] cache_config     Prefered cache configuration
   */
  void set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config);

Tim Moon's avatar
Tim Moon committed
159
160
161
162
163
164
165
166
 private:
  /*! \brief Compiled kernels */
  std::unordered_map<std::string, Kernel> kernel_cache_;
  /*! \brief Mutex for thread-safe compilation */
  std::mutex lock_;

  KernelManager() = default;
  ~KernelManager() = default;
167
168
  KernelManager(const KernelManager &) = delete;
  KernelManager &operator=(const KernelManager &) = delete;
Tim Moon's avatar
Tim Moon committed
169
170
171
172
173
174
175
176

  /*! \brief Construct key for kernel cache
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] device_id    CUDA device (default is current device)
   *
   * \return Key for kernel cache
   */
177
  std::string get_kernel_cache_key(const std::string &kernel_label, int device_id) const;
Tim Moon's avatar
Tim Moon committed
178
179
180
181
182
183
184
};

}  // namespace rtc

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_