/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cstdlib>
#include <iostream>
#include <utility>

#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/string.h"
#include "../util/system.h"

#include "../util/rtc.h"

namespace transformer_engine {

namespace rtc {

namespace {

// Strings with headers for RTC kernels
#include "string_code_utils_cuh.h"
#include "string_code_util_math_h.h"

/*! \brief Latest compute capability that NVRTC supports
 *
 * NVRTC is queried on the first successful call and the answer is cached
 * for the rest of the process. If the query throws, nothing is cached and
 * the next call queries again.
 *
 * \return Compute capability as int. Last digit is minor revision,
 *         remaining digits are major revision.
 */
inline int max_supported_sm_arch() {
  static const int latest_arch = [] {
    int arch_count = 0;
    NVTE_CHECK_NVRTC(nvrtcGetNumSupportedArchs(&arch_count));
    NVTE_CHECK(arch_count > 0, "Could not determine SM archs that NVRTC supports");
    std::vector<int> supported(arch_count);
    NVTE_CHECK_NVRTC(nvrtcGetSupportedArchs(supported.data()));
    // NVRTC documents this list as sorted in ascending order, so the
    // newest architecture is the last entry
    return supported.back();
  }();
  return latest_arch;
}

}  // namespace

bool is_enabled() {
  // NVRTC is on unless getenv<bool> reports NVTE_DISABLE_NVRTC as true.
  // The environment is read once, on the first call that completes;
  // later changes to the variable are not observed.
  static const bool enabled = !getenv<bool>("NVTE_DISABLE_NVRTC");
  return enabled;
}

Kernel::Kernel(std::string mangled_name, std::string compiled_code)
  : mangled_name_{std::move(mangled_name)}, compiled_code_{std::move(compiled_code)} {
  // Per-device state is indexed by device ID. Nothing is loaded here: each
  // slot starts out empty and is filled lazily by get_function.
  const auto device_count = cuda::num_devices();
  modules_.assign(device_count, null_module);
  functions_.assign(device_count, null_function);
  init_flags_ = std::make_unique<std::vector<std::once_flag>>(device_count);
}

Kernel::~Kernel() {
  // Unloading a module requires its device's context to be current, so this
  // destructor switches contexts. The caller's current context is remembered
  // and put back at the end. It is only queried once a loaded module is
  // found, so a Kernel that never loaded anything (e.g. a moved-from one,
  // whose modules_ is default-constructed) makes no driver calls at all.
  CUcontext saved_context = nullptr;
  bool context_queried = false;
  bool context_saved = false;

  for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) {
    // Nothing to unload on devices where the kernel was never loaded
    if (modules_[device_id] == null_module) {
      continue;
    }
    if (!context_queried) {
      context_queried = true;
      context_saved = (cuda_driver::call("cuCtxGetCurrent", &saved_context) == CUDA_SUCCESS);
    }

    // Best effort: driver errors are ignored rather than reported. Kernels
    // cached by the static KernelManager are destroyed at program exit, when
    // the driver may no longer be usable.
    CUdevice device;
    CUcontext context;
    if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) {
      continue;
    }
    if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) != CUDA_SUCCESS) {
      continue;
    }
    if (cuda_driver::call("cuCtxSetCurrent", context) == CUDA_SUCCESS) {
      cuda_driver::call("cuModuleUnload", modules_[device_id]);
    }
    // Balance the retain above even when the context could not be made
    // current, so the primary-context reference is never leaked.
    cuda_driver::call("cuDevicePrimaryCtxRelease", device);
  }

  // Put back whichever context (possibly none) was current beforehand, so
  // destroying a Kernel does not change the calling thread's context.
  if (context_saved) {
    cuda_driver::call("cuCtxSetCurrent", saved_context);
  }
}

Kernel::Kernel(Kernel&& other) noexcept {
  // No mem-initializers: every member of *this starts default-constructed.
  // Trading state with `other` leaves it holding those defaults.
  swap(other, *this);
}

Kernel& Kernel::operator=(Kernel other) noexcept {
  // Copy/move-and-swap: `other` is a by-value parameter built by the caller.
  // After the exchange it owns our previous state, which its destructor
  // cleans up when this function returns.
  swap(other, *this);
  return *this;
}

void swap(Kernel& first, Kernel& second) noexcept {
  // Exchange every member of the two kernels. This is the building block
  // for the move constructor and assignment operator.
  first.init_flags_.swap(second.init_flags_);
  first.functions_.swap(second.functions_);
  first.modules_.swap(second.modules_);
  first.compiled_code_.swap(second.compiled_code_);
  first.mangled_name_.swap(second.mangled_name_);
}

CUfunction Kernel::get_function(int device_id) {
  // Load kernel on device if needed. std::call_once only marks the flag as
  // done when the lambda returns normally; if it throws, a later call will
  // try again.
  auto load_on_device = [&] () {
    // Set driver context to proper device
    CUdevice device;
    CUcontext context;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, device_id);
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);

    try {
      NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);

      // Load function into driver context
      NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx,
                                  &modules_[device_id],
                                  compiled_code_.c_str(),
                                  0,          // numOptions
                                  nullptr,    // options
                                  nullptr);   // optionValues
      NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction,
                                  &functions_[device_id],
                                  modules_[device_id],
                                  mangled_name_.c_str());
    } catch (...) {
      // Balance cuDevicePrimaryCtxRetain before propagating, so a failed
      // attempt does not leak a primary-context reference. A failure while
      // releasing is deliberately ignored so it cannot replace the original
      // error.
      try {
        NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
      } catch (...) {
      }
      throw;
    }

    // Balance cuDevicePrimaryCtxRetain. Note: this does not restore whichever
    // context was current before the cuCtxSetCurrent call above.
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
  };
  // at() throws std::out_of_range for an invalid device_id. init_flags_ has
  // one entry per device, like modules_/functions_, so this check happens
  // before those vectors are indexed.
  std::call_once(init_flags_->at(device_id), load_on_device);

  // Return CUDA function
  return functions_[device_id];
}

void Kernel::set_function_cache_config(int device_id, CUfunc_cache cache_config) {
  NVTE_CALL_CHECK_CUDA_DRIVER(cuFuncSetCacheConfig, get_function(device_id), cache_config);
}

KernelManager& KernelManager::instance() {
  // The enablement check runs on every call, including before the
  // function-local singleton is first constructed.
  NVTE_CHECK(is_enabled(), "NVRTC support is not enabled");
  static KernelManager singleton;
  return singleton;
}

void KernelManager::compile(const std::string &kernel_label,
                            const std::string &kernel_name,
                            const std::string &code,
                            const std::string &filename) {
  std::lock_guard<std::mutex> lock_guard_(lock_);

  // Choose whether to compile to PTX or cubin
  const int device_id = cuda::current_device();
  const int sm_arch_ = cuda::sm_arch(device_id);
  const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
  const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);

  // Compilation flags
  std::vector<std::string> opts = {
#if NDEBUG == 0
    "-G",
#endif
    "--std=c++17"};
  if (compile_ptx) {
    opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch));
  } else {
    opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch));
  }
  opts.push_back(concat_strings("-I", cuda::include_directory(true)));
  std::vector<const char*> opts_ptrs;
  for (const auto& opt : opts) {
    opts_ptrs.push_back(opt.c_str());
  }

  // Compile source
  nvrtcProgram program;
181
182
183
  constexpr int num_headers = 2;
  constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
  constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
Tim Moon's avatar
Tim Moon committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
  NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program,
                                      code.c_str(),
                                      filename.c_str(),
                                      num_headers,
                                      headers,
                                      include_names));
  NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
  const nvrtcResult compile_result = nvrtcCompileProgram(program,
                                                         opts_ptrs.size(),
                                                         opts_ptrs.data());
  if (compile_result != NVRTC_SUCCESS) {
    // Display log if compilation failed
    std::string log = concat_strings("NVRTC compilation log for ",
                                     filename, ":\n");
    const size_t log_offset = log.size();
    size_t log_size;
    NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size));
    log.resize(log_offset + log_size);
    NVTE_CHECK_NVRTC(nvrtcGetProgramLog(program, &log[log_offset]));
    log.back() = '\n';
    std::cerr << log;
    NVTE_CHECK_NVRTC(compile_result);
  }

  // Get mangled function name
  const char *mangled_name;
  NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program,
                                       kernel_name.c_str(),
                                       &mangled_name));

  // Get compiled code
  std::string compiled_code;
  if (compile_ptx) {
    size_t compiled_size;
    NVTE_CHECK_NVRTC(nvrtcGetPTXSize(program, &compiled_size));
    compiled_code.resize(compiled_size);
    NVTE_CHECK_NVRTC(nvrtcGetPTX(program, compiled_code.data()));
  } else {
    size_t compiled_size;
    NVTE_CHECK_NVRTC(nvrtcGetCUBINSize(program, &compiled_size));
    compiled_code.resize(compiled_size);
    NVTE_CHECK_NVRTC(nvrtcGetCUBIN(program, compiled_code.data()));
  }

  // Cache compiled code
  const auto key = get_kernel_cache_key(kernel_label, device_id);
  kernel_cache_.insert({key, Kernel(mangled_name, std::move(compiled_code))});
  kernel_cache_.at(key).get_function(device_id);  // Make sure kernel is available on device

  // Clean up
  NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program));
}

void KernelManager::set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config) {
  // Applies to the kernel compiled for the current device's SM architecture
  const int device_id = cuda::current_device();
  const auto key = get_kernel_cache_key(kernel_label, device_id);
  // Single lookup, instead of count() followed by at()
  const auto entry = kernel_cache_.find(key);
  NVTE_CHECK(entry != kernel_cache_.end(),
             "Attempted to configure RTC kernel before compilation");
  entry->second.set_function_cache_config(device_id, cache_config);
}

bool KernelManager::is_compiled(const std::string &kernel_label, int device_id) const {
  // The cache key is built from the device's SM architecture, not its ID, so
  // this reports an entry for any device sharing that architecture.
  // NOTE(review): unlike compile(), this reads kernel_cache_ without lock_.
  const auto key = get_kernel_cache_key(kernel_label, device_id);
  return kernel_cache_.find(key) != kernel_cache_.end();
}

std::string KernelManager::get_kernel_cache_key(const std::string &kernel_label,
                                                int device_id) const {
  // Keyed on the device's SM architecture rather than its ID, so devices of
  // the same architecture share one cache entry. Format: "sm=<arch>,<label>".
  const auto arch = cuda::sm_arch(device_id);
  return concat_strings("sm=", arch, ",", kernel_label);
}

}  // namespace rtc

}  // namespace transformer_engine