/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>

#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"

namespace transformer_engine {

namespace rtc {

/*! \brief Whether NVRTC support is enabled
 *
 * NVRTC support can be disabled by setting NVTE_DISABLE_NVRTC=1 in
 * the environment.
 *
 * \return Whether NVRTC support is enabled
 */
bool is_enabled();

/*! \brief Wrapper class for a runtime-compiled CUDA kernel */
class Kernel {
 public:
  Kernel(std::string mangled_name, std::string compiled_code);
  ~Kernel();
  Kernel(const Kernel&) = delete;  // move-only
  Kernel(Kernel&&) noexcept;
  Kernel& operator=(Kernel) noexcept;
  friend void swap(Kernel& first, Kernel& second) noexcept;

  /*! \brief Launch CUDA kernel
   *
   * Loads the kernel into the device the first time the device is
   * accessed.
   *
   * \param[in] device_id        CUDA device
   * \param[in] grid_dim         Grid dimensions in blocks
   * \param[in] block_dim        Thread block dimensions
   * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
   *                             bytes
   * \param[in] stream           CUDA stream
   * \param[in] args             Kernel arguments
   */
  template <typename... ArgTs>
  void launch(int device_id,
              const dim3 grid_dim,
              const dim3 block_dim,
              unsigned int shared_mem_bytes,
              cudaStream_t stream,
              ArgTs &&... args) {
    void* arg_ptrs[] = { const_cast<void*>(static_cast<const void*>(&args))... };
    NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel,
                                get_function(device_id),
                                grid_dim.x,
                                grid_dim.y,
                                grid_dim.z,
                                block_dim.x,
                                block_dim.y,
                                block_dim.z,
                                shared_mem_bytes,
                                static_cast<CUstream>(stream),
                                arg_ptrs,
                                nullptr);
  }

  /*! \brief CUDA function for given CUDA device
   *
   * Loads the kernel into the device the first time the device is
   * accessed.
   */
  CUfunction get_function(int device_id);

88
89
90
91
92
93
  /*! \brief Sets the preferred cache configuration for a function
   *
   * Wrapper of the CUDA Driver API function "cuFuncSetCacheConfig"
   */
  void set_function_cache_config(int device_id, CUfunc_cache cache_config);

Tim Moon's avatar
Tim Moon committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
 private:
  /*! \brief Mangled function name */
  std::string mangled_name_;
  /*! \brief  Compiled assembly, either in PTX or cubin format */
  std::string compiled_code_;
  /*! CUDA module for each CUDA device */
  std::vector<CUmodule> modules_;
  /*! CUDA function for each CUDA device */
  std::vector<CUfunction> functions_;

  /*! Flags for thread-safe kernel initialization */
  std::unique_ptr<std::vector<std::once_flag>> init_flags_;

  /*! \brief Uninitialized CUDA module */
  static constexpr CUmodule null_module = static_cast<CUmodule>(nullptr);
  /*! Uninitialized CUDA function */
  static constexpr CUfunction null_function = static_cast<CUfunction>(nullptr);
};

/*! \brief Singleton class to manage runtime-compiled CUDA kernels */
class KernelManager {
 public:
  /*! \brief Get singleton instance */
  static KernelManager& instance();

  /*! \brief Compile CUDA kernel for current CUDA device
   *
   * The compiled kernel is cached and made available for launching.
   *
   * \param[in] kernel_label Unique identifying string for kernel
   * \param[in] kernel_name  Kernel name within source code
   * \param[in] code         Kernel source code
   * \param[in] filename     Path to associate with source code,
   *                         primarily for debugging
   */
  void compile(const std::string &kernel_label,
               const std::string &kernel_name,
               const std::string &code,
               const std::string &filename);

  /*! \brief Whether CUDA kernel has been compiled for CUDA device
   *
   * \param[in] kernel_label Unique identifying string for kernel
   * \param[in] device_id    CUDA device (default is current device)

   * \return Whether kernel has been compiled
   */
  bool is_compiled(const std::string &kernel_label,
                   int device_id = -1) const;

  /*! \brief Launch CUDA kernel on current CUDA device
   *
   * Assumes the kernel has already been compiled.
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] grid_dim         Grid dimensions in blocks
   * \param[in] block_dim        Thread block dimensions
   * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
   *                             bytes
   * \param[in] stream           CUDA stream
   * \param[in] args             Kernel arguments
   */
  template <typename... ArgTs>
  void launch(const std::string &kernel_label,
              const dim3 grid_dim,
              const dim3 block_dim,
              unsigned int shared_mem_bytes,
              cudaStream_t stream,
              ArgTs &&... args) {
    const int device_id = cuda::current_device();
    const auto key = get_kernel_cache_key(kernel_label, device_id);
    NVTE_CHECK(kernel_cache_.count(key) > 0,
               "Attempted to launch RTC kernel before compilation");
    kernel_cache_.at(key).launch(device_id,
                                 grid_dim,
                                 block_dim,
                                 shared_mem_bytes,
                                 stream,
                                 std::forward<ArgTs>(args)...);
  }

175
176
177
178
179
180
181
182
183
  /*! \brief Sets the preferred cache configuration for a function in the context
   *
   * Assumes the kernel has already been compiled.
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] cache_config     Prefered cache configuration
   */
  void set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config);

Tim Moon's avatar
Tim Moon committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
 private:
  /*! \brief Compiled kernels */
  std::unordered_map<std::string, Kernel> kernel_cache_;
  /*! \brief Mutex for thread-safe compilation */
  std::mutex lock_;

  KernelManager() = default;
  ~KernelManager() = default;
  KernelManager(const KernelManager&) = delete;
  KernelManager& operator=(const KernelManager&) = delete;

  /*! \brief Construct key for kernel cache
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] device_id    CUDA device (default is current device)
   *
   * \return Key for kernel cache
   */
  std::string get_kernel_cache_key(const std::string &kernel_label,
                                   int device_id) const;
};

}  // namespace rtc

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_