/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <nvrtc.h>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../common.h"

#ifdef __HIP_PLATFORM_AMD__
#include "../util/hip_driver.h"
#include "../util/hip_runtime.h"
#else
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
#endif

namespace transformer_engine {

namespace rtc {

/*! \brief Whether NVRTC support is enabled
 *
 * NVRTC support can be disabled by setting NVTE_DISABLE_NVRTC=1 in
 * the environment.
 */
bool is_enabled();

/*! \brief Wrapper class for a runtime-compiled CUDA kernel */
class Kernel {
 public:
  Kernel(std::string mangled_name, std::string compiled_code);
  ~Kernel();
46
47
48
49
  Kernel(const Kernel &) = delete;  // move-only
  Kernel(Kernel &&) noexcept;
  Kernel &operator=(Kernel) noexcept;
  friend void swap(Kernel &first, Kernel &second) noexcept;
Tim Moon's avatar
Tim Moon committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

  /*! \brief Launch CUDA kernel
   *
   * Loads the kernel into the device the first time the device is
   * accessed.
   *
   * \param[in] device_id        CUDA device
   * \param[in] grid_dim         Grid dimensions in blocks
   * \param[in] block_dim        Thread block dimensions
   * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
   *                             bytes
   * \param[in] stream           CUDA stream
   * \param[in] args             Kernel arguments
   */
  template <typename... ArgTs>
65
66
67
68
69
70
  void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
              unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
    void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
    NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
                                grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
                                static_cast<CUstream>(stream), arg_ptrs, nullptr);
Tim Moon's avatar
Tim Moon committed
71
72
73
74
75
76
77
78
79
  }

  /*! \brief CUDA function for given CUDA device
   *
   * Loads the kernel into the device the first time the device is
   * accessed.
   */
  CUfunction get_function(int device_id);

80
81
82
83
84
85
  /*! \brief Sets the preferred cache configuration for a function
   *
   * Wrapper of the CUDA Driver API function "cuFuncSetCacheConfig"
   */
  void set_function_cache_config(int device_id, CUfunc_cache cache_config);

Tim Moon's avatar
Tim Moon committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
 private:
  /*! \brief Mangled function name */
  std::string mangled_name_;
  /*! \brief  Compiled assembly, either in PTX or cubin format */
  std::string compiled_code_;
  /*! CUDA module for each CUDA device */
  std::vector<CUmodule> modules_;
  /*! CUDA function for each CUDA device */
  std::vector<CUfunction> functions_;

  /*! Flags for thread-safe kernel initialization */
  std::unique_ptr<std::vector<std::once_flag>> init_flags_;

  /*! \brief Uninitialized CUDA module */
  static constexpr CUmodule null_module = static_cast<CUmodule>(nullptr);
  /*! Uninitialized CUDA function */
  static constexpr CUfunction null_function = static_cast<CUfunction>(nullptr);
};

/*! \brief Singleton class to manage runtime-compiled CUDA kernels */
class KernelManager {
 public:
  /*! \brief Get singleton instance */
109
  static KernelManager &instance();
Tim Moon's avatar
Tim Moon committed
110
111
112
113
114
115
116
117
118
119
120

  /*! \brief Compile CUDA kernel for current CUDA device
   *
   * The compiled kernel is cached and made available for launching.
   *
   * \param[in] kernel_label Unique identifying string for kernel
   * \param[in] kernel_name  Kernel name within source code
   * \param[in] code         Kernel source code
   * \param[in] filename     Path to associate with source code,
   *                         primarily for debugging
   */
121
122
  void compile(const std::string &kernel_label, const std::string &kernel_name,
               const std::string &code, const std::string &filename);
Tim Moon's avatar
Tim Moon committed
123
124
125
126
127
128
129
130

  /*! \brief Whether CUDA kernel has been compiled for CUDA device
   *
   * \param[in] kernel_label Unique identifying string for kernel
   * \param[in] device_id    CUDA device (default is current device)

   * \return Whether kernel has been compiled
   */
131
  bool is_compiled(const std::string &kernel_label, int device_id = -1) const;
Tim Moon's avatar
Tim Moon committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145

  /*! \brief Launch CUDA kernel on current CUDA device
   *
   * Assumes the kernel has already been compiled.
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] grid_dim         Grid dimensions in blocks
   * \param[in] block_dim        Thread block dimensions
   * \param[in] shared_mem_bytes Dynamic shared-memory size per thread block in
   *                             bytes
   * \param[in] stream           CUDA stream
   * \param[in] args             Kernel arguments
   */
  template <typename... ArgTs>
146
147
  void launch(const std::string &kernel_label, const dim3 grid_dim, const dim3 block_dim,
              unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
Tim Moon's avatar
Tim Moon committed
148
149
    const int device_id = cuda::current_device();
    const auto key = get_kernel_cache_key(kernel_label, device_id);
150
151
    NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to launch RTC kernel before compilation");
    kernel_cache_.at(key).launch(device_id, grid_dim, block_dim, shared_mem_bytes, stream,
Tim Moon's avatar
Tim Moon committed
152
153
154
                                 std::forward<ArgTs>(args)...);
  }

155
156
157
158
159
160
161
162
163
  /*! \brief Sets the preferred cache configuration for a function in the context
   *
   * Assumes the kernel has already been compiled.
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] cache_config     Prefered cache configuration
   */
  void set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config);

Tim Moon's avatar
Tim Moon committed
164
165
166
167
168
169
170
171
 private:
  /*! \brief Compiled kernels */
  std::unordered_map<std::string, Kernel> kernel_cache_;
  /*! \brief Mutex for thread-safe compilation */
  std::mutex lock_;

  KernelManager() = default;
  ~KernelManager() = default;
172
173
  KernelManager(const KernelManager &) = delete;
  KernelManager &operator=(const KernelManager &) = delete;
Tim Moon's avatar
Tim Moon committed
174
175
176
177
178
179
180
181

  /*! \brief Construct key for kernel cache
   *
   * \param[in] kernel_label     Unique identifying string for kernel
   * \param[in] device_id    CUDA device (default is current device)
   *
   * \return Key for kernel cache
   */
182
  std::string get_kernel_cache_key(const std::string &kernel_label, int device_id) const;
Tim Moon's avatar
Tim Moon committed
183
184
185
186
187
188
189
};

}  // namespace rtc

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_RTC_H_