/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
#include "../util/rtc.h"

Tim Moon's avatar
Tim Moon committed
9
10
11
12
13
#include <cstdlib>
#include <iostream>
#include <utility>

#include "../common.h"
yuguo's avatar
yuguo committed
14
15
16
#ifdef USE_ROCM
#include "../util/hip_driver.h"
#else
Tim Moon's avatar
Tim Moon committed
17
#include "../util/cuda_driver.h"
yuguo's avatar
yuguo committed
18
#endif
Tim Moon's avatar
Tim Moon committed
19
20
21
22
23
24
25
26
27
28
#include "../util/string.h"
#include "../util/system.h"

namespace transformer_engine {

namespace rtc {

namespace {

// Strings with headers for RTC kernels
29
#include "string_code_util_math_h.h"
30
#include "string_code_utils_cuh.h"
Tim Moon's avatar
Tim Moon committed
31

yuguo's avatar
yuguo committed
32
#ifndef USE_ROCM
Tim Moon's avatar
Tim Moon committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/*! \brief Latest compute capability that NVRTC supports
 *
 * NVRTC is queried on the first successful call and the answer is cached
 * for the rest of the process. If the query throws, the next call retries.
 *
 * \return Compute capability as int. Last digit is minor revision,
 *         remaining digits are major revision.
 */
inline int max_supported_sm_arch() {
  static const int cached_arch = []() {
    int arch_count = 0;
    NVTE_CHECK_NVRTC(nvrtcGetNumSupportedArchs(&arch_count));
    NVTE_CHECK(arch_count > 0, "Could not determine SM archs that NVRTC supports");
    std::vector<int> supported(arch_count);
    NVTE_CHECK_NVRTC(nvrtcGetSupportedArchs(supported.data()));
    // Takes the last entry as the newest arch; this relies on NVRTC reporting
    // the list in ascending order (per the NVRTC API docs).
    return supported.back();
  }();
  return cached_arch;
}
yuguo's avatar
yuguo committed
50
#endif // USE_ROCM
Tim Moon's avatar
Tim Moon committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64

}  // namespace

/*! \brief Whether runtime compilation with NVRTC is enabled.
 *
 * RTC is enabled unless the NVTE_DISABLE_NVRTC environment variable
 * evaluates to true. The environment is read once, and the result is cached
 * for the lifetime of the process.
 */
bool is_enabled() {
  // A function-local static is initialized exactly once, even under
  // concurrent calls. A hand-rolled check-then-set on plain bools would be a
  // data race. If getenv throws, initialization is retried on the next call.
  static const bool is_enabled_ = !getenv<bool>("NVTE_DISABLE_NVRTC");
  return is_enabled_;
}

/* Takes ownership of the lowered kernel name and its compiled image.
 * Nothing is loaded on any device here; per-device module/function slots
 * start out null and are filled lazily by get_function().
 */
Kernel::Kernel(std::string mangled_name, std::string compiled_code)
    : mangled_name_{std::move(mangled_name)},
      compiled_code_{std::move(compiled_code)},
      // One slot per device, indexed by device ID
      modules_(cuda::num_devices(), null_module),
      functions_(cuda::num_devices(), null_function),
      // One once_flag per device so each device loads the kernel exactly once
      init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} {}
Tim Moon's avatar
Tim Moon committed
70
71

/* Unloads the kernel module from every device it was loaded on.
 *
 * Cleanup is best-effort: a destructor must not throw, so driver errors are
 * ignored. A device whose cleanup fails is skipped.
 */
Kernel::~Kernel() {
  for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) {
    // Nothing to do for devices where the kernel was never loaded
    if (modules_[device_id] == null_module) {
      continue;
    }
#ifdef USE_ROCM
    (void)cuda_driver::call("hipModuleUnload", modules_[device_id]);
#else
    // get_function() loads the module into the device's primary context, so
    // that context has to be current while the module is unloaded.
    CUdevice device;
    if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) {
      continue;
    }
    CUcontext context;
    if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) != CUDA_SUCCESS) {
      continue;
    }
    if (cuda_driver::call("cuCtxSetCurrent", context) == CUDA_SUCCESS) {
      cuda_driver::call("cuModuleUnload", modules_[device_id]);
    }
    // Always balance the successful retain above, even if the context could
    // not be made current; otherwise the primary-context reference leaks.
    cuda_driver::call("cuDevicePrimaryCtxRelease", device);
#endif  // USE_ROCM
  }
}

96
Kernel::Kernel(Kernel&& other) noexcept {
  // Members start default-initialized. Exchanging with `other` takes over its
  // state and leaves `other` with this object's default-initialized state.
  swap(*this, other);
}
Tim Moon's avatar
Tim Moon committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

Kernel& Kernel::operator=(Kernel other) noexcept {
  // Copy-and-swap: `other` is a by-value parameter that already holds the new
  // state. Exchange it with ours; the old state is released when `other` is
  // destroyed at the end of this call.
  swap(*this, other);
  return *this;
}

/* Exchanges the complete state of two kernels member by member. */
void swap(Kernel& first, Kernel& second) noexcept {
  // Two-step swap: use a member type's own swap if ADL finds one, else std::swap
  using std::swap;
  swap(first.init_flags_, second.init_flags_);
  swap(first.functions_, second.functions_);
  swap(first.modules_, second.modules_);
  swap(first.compiled_code_, second.compiled_code_);
  swap(first.mangled_name_, second.mangled_name_);
}

/* Returns the kernel's function handle on `device_id`, loading the compiled
 * module into that device's primary context on first use.
 *
 * The load runs at most once per device (std::call_once). If it throws, the
 * flag stays unset and a later call retries. An out-of-range device_id
 * throws std::out_of_range from init_flags_->at().
 */
CUfunction Kernel::get_function(int device_id) {
  // Load kernel on device if needed
  auto load_on_device = [&]() {
    // Make the device's primary context current so the module is loaded into it
    CUdevice device;
    CUcontext context;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, device_id);
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);

    try {
      NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);

      // Load function into driver context
      NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, &modules_[device_id], compiled_code_.c_str(),
                                  0,         // numOptions
                                  nullptr,   // options
                                  nullptr);  // optionValues
      NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction, &functions_[device_id], modules_[device_id],
                                  mangled_name_.c_str());
    } catch (...) {
      // Balance the retain before propagating. Otherwise every failed (and,
      // via call_once, retried) load leaks a primary-context reference.
      try {
        NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
      } catch (...) {
        // Ignore: the original error is the one worth reporting
      }
      throw;
    }

    // Drop our reference to the primary context. The previous context is not
    // restored; the primary context stays current on this thread.
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
  };
  std::call_once(init_flags_->at(device_id), load_on_device);

  // Return CUDA function
  return functions_[device_id];
}

140
141
142
143
void Kernel::set_function_cache_config(int device_id, CUfunc_cache cache_config) {
  NVTE_CALL_CHECK_CUDA_DRIVER(cuFuncSetCacheConfig, get_function(device_id), cache_config);
}

Tim Moon's avatar
Tim Moon committed
144
145
146
147
148
149
/* Process-wide kernel manager. Throws (via NVTE_CHECK) if RTC is disabled;
 * in that case the singleton is never constructed.
 */
KernelManager& KernelManager::instance() {
  NVTE_CHECK(is_enabled(), "NVRTC support is not enabled");
  // Constructed on first successful call, destroyed at program exit
  static KernelManager manager;
  return manager;
}

150
151
void KernelManager::compile(const std::string& kernel_label, const std::string& kernel_name,
                            const std::string& code, const std::string& filename) {
Tim Moon's avatar
Tim Moon committed
152
153
154
155
  std::lock_guard<std::mutex> lock_guard_(lock_);

  // Choose whether to compile to PTX or cubin
  const int device_id = cuda::current_device();
yuguo's avatar
yuguo committed
156
#ifndef USE_ROCM
Tim Moon's avatar
Tim Moon committed
157
158
159
  const int sm_arch_ = cuda::sm_arch(device_id);
  const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
  const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);
yuguo's avatar
yuguo committed
160
#endif // USE_ROCM
Tim Moon's avatar
Tim Moon committed
161
162
163
164

  // Compilation flags
  std::vector<std::string> opts = {
#if NDEBUG == 0
165
      "-G",
Tim Moon's avatar
Tim Moon committed
166
#endif
167
      "--std=c++17"};
yuguo's avatar
yuguo committed
168
169

#ifndef USE_ROCM
Tim Moon's avatar
Tim Moon committed
170
171
172
173
174
175
  if (compile_ptx) {
    opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch));
  } else {
    opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch));
  }
  opts.push_back(concat_strings("-I", cuda::include_directory(true)));
yuguo's avatar
yuguo committed
176
#endif //USE_ROCM
Tim Moon's avatar
Tim Moon committed
177
178
179
180
181
182
183
  std::vector<const char*> opts_ptrs;
  for (const auto& opt : opts) {
    opts_ptrs.push_back(opt.c_str());
  }

  // Compile source
  nvrtcProgram program;
yuguo's avatar
yuguo committed
184
#ifdef USE_ROCM
wenjh's avatar
wenjh committed
185
186
187
  constexpr int num_headers = 2;
  const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
  const char* include_names[num_headers] = {"utils_hip.cuh", "util/math.h"};
yuguo's avatar
yuguo committed
188
#else
189
190
191
  constexpr int num_headers = 2;
  constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
  constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
yuguo's avatar
yuguo committed
192
#endif // USE_ROCM
193
194
  NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers,
                                      headers, include_names));
Tim Moon's avatar
Tim Moon committed
195
  NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
196
197
  const nvrtcResult compile_result =
      nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data());
Tim Moon's avatar
Tim Moon committed
198
199
  if (compile_result != NVRTC_SUCCESS) {
    // Display log if compilation failed
200
    std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n");
Tim Moon's avatar
Tim Moon committed
201
202
203
204
205
206
207
208
209
210
211
    const size_t log_offset = log.size();
    size_t log_size;
    NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size));
    log.resize(log_offset + log_size);
    NVTE_CHECK_NVRTC(nvrtcGetProgramLog(program, &log[log_offset]));
    log.back() = '\n';
    std::cerr << log;
    NVTE_CHECK_NVRTC(compile_result);
  }

  // Get mangled function name
212
213
  const char* mangled_name;
  NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, kernel_name.c_str(), &mangled_name));
Tim Moon's avatar
Tim Moon committed
214
215
216

  // Get compiled code
  std::string compiled_code;
yuguo's avatar
yuguo committed
217
218
219
220
221
222
223
224
#ifdef USE_ROCM
  {
    size_t compiled_size;
    NVTE_CHECK_NVRTC(hiprtcGetCodeSize(program, &compiled_size));
    compiled_code.resize(compiled_size);
    NVTE_CHECK_NVRTC(hiprtcGetCode(program, compiled_code.data()));
  }
#else
Tim Moon's avatar
Tim Moon committed
225
226
227
228
229
230
231
232
233
234
235
  if (compile_ptx) {
    size_t compiled_size;
    NVTE_CHECK_NVRTC(nvrtcGetPTXSize(program, &compiled_size));
    compiled_code.resize(compiled_size);
    NVTE_CHECK_NVRTC(nvrtcGetPTX(program, compiled_code.data()));
  } else {
    size_t compiled_size;
    NVTE_CHECK_NVRTC(nvrtcGetCUBINSize(program, &compiled_size));
    compiled_code.resize(compiled_size);
    NVTE_CHECK_NVRTC(nvrtcGetCUBIN(program, compiled_code.data()));
  }
yuguo's avatar
yuguo committed
236
#endif //USE_ROCM
Tim Moon's avatar
Tim Moon committed
237
238
239
240
241
242
243
244
245
246

  // Cache compiled code
  const auto key = get_kernel_cache_key(kernel_label, device_id);
  kernel_cache_.insert({key, Kernel(mangled_name, std::move(compiled_code))});
  kernel_cache_.at(key).get_function(device_id);  // Make sure kernel is available on device

  // Clean up
  NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program));
}

247
void KernelManager::set_cache_config(const std::string& kernel_label, CUfunc_cache cache_config) {
248
249
  const int device_id = cuda::current_device();
  const auto key = get_kernel_cache_key(kernel_label, device_id);
250
  NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to configure RTC kernel before compilation");
251
252
253
  kernel_cache_.at(key).set_function_cache_config(device_id, cache_config);
}

254
bool KernelManager::is_compiled(const std::string& kernel_label, int device_id) const {
Tim Moon's avatar
Tim Moon committed
255
256
257
258
  const auto key = get_kernel_cache_key(kernel_label, device_id);
  return kernel_cache_.count(key) > 0;
}

259
std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label,
Tim Moon's avatar
Tim Moon committed
260
                                                int device_id) const {
yuguo's avatar
yuguo committed
261
262
263
#ifdef USE_ROCM
  return concat_strings(cuda::sm_arch_name(device_id), ",", kernel_label);
#else
Tim Moon's avatar
Tim Moon committed
264
  return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label);
yuguo's avatar
yuguo committed
265
#endif
Tim Moon's avatar
Tim Moon committed
266
267
268
269
270
}

}  // namespace rtc

}  // namespace transformer_engine