/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
#include "../util/cuda_runtime.h"

9
10
#include <cublasLt.h>

Tim Moon's avatar
Tim Moon committed
11
12
13
14
15
16
#include <filesystem>
#include <mutex>

#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/system.h"
17
#include "common/util/cuda_runtime.h"
Tim Moon's avatar
Tim Moon committed
18
19
20
21
22

namespace transformer_engine {

namespace cuda {

yuguo's avatar
yuguo committed
23
#ifndef __HIP_PLATFORM_AMD__
namespace {

// String with build-time CUDA include path.
// The included header is expected to define `string_path_cuda_include`,
// which include_directory() uses as one of its search locations.
#include "string_path_cuda_include.h"

}  // namespace
#endif  // __HIP_PLATFORM_AMD__
Tim Moon's avatar
Tim Moon committed
31
32

int num_devices() {
33
  auto query_num_devices = []() -> int {
Tim Moon's avatar
Tim Moon committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    int count;
    NVTE_CHECK_CUDA(cudaGetDeviceCount(&count));
    return count;
  };
  static int num_devices_ = query_num_devices();
  return num_devices_;
}

// ID of the device that is current for the calling thread.
//
// The driver is asked for the current context first; when there is none,
// device 0 is reported without touching the runtime API.
int current_device() {
  CUcontext context;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);

  // A context exists: the runtime knows which device is selected
  if (context != nullptr) {
    int device_id;
    NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
    return device_id;
  }

  // No context has been set up yet
  return 0;
}

int sm_arch(int device_id) {
  static std::vector<int> cache(num_devices(), -1);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
63
  auto init = [&]() {
Tim Moon's avatar
Tim Moon committed
64
65
    cudaDeviceProp prop;
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
66
    cache[device_id] = 10 * prop.major + prop.minor;
Tim Moon's avatar
Tim Moon committed
67
68
69
70
71
72
73
74
75
76
77
78
  };
  std::call_once(flags[device_id], init);
  return cache[device_id];
}

int sm_count(int device_id) {
  static std::vector<int> cache(num_devices(), -1);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
79
  auto init = [&]() {
Tim Moon's avatar
Tim Moon committed
80
81
82
83
84
85
86
87
    cudaDeviceProp prop;
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
    cache[device_id] = prop.multiProcessorCount;
  };
  std::call_once(flags[device_id], init);
  return cache[device_id];
}

yuguo's avatar
yuguo committed
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#ifdef __HIP_PLATFORM_AMD__
// Architecture name string of a device, taken from the gcnArchName field of
// its device properties.
//
// device_id: device to query; a negative value selects current_device().
// The returned reference points into a static per-device cache that is
// filled at most once per device.
const std::string &sm_arch_name(int device_id) {
  static std::vector<std::once_flag> init_flags(num_devices());
  static std::vector<std::string> arch_names(num_devices(), "");

  // Negative IDs mean "whatever device is current"
  if (device_id < 0) device_id = current_device();
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid HIP device ID");

  // First caller for this device fills the cache entry
  std::call_once(init_flags[device_id], [&]() {
    cudaDeviceProp prop;
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
    arch_names[device_id] = prop.gcnArchName;
  });
  return arch_names[device_id];
}
#endif // __HIP_PLATFORM_AMD__

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Stream priority range of a device, as reported by
// cudaDeviceGetStreamPriorityRange.
//
// low_priority / high_priority: output locations; they receive the first and
//   second value returned by the runtime query, respectively.
// device_id: device to query; a negative value selects current_device().
//
// The range is queried once per device (guarded by a per-device
// std::once_flag) and served from a static cache afterwards. The query
// applies to the current device, so the device is switched temporarily when
// device_id differs from it.
void stream_priority_range(int *low_priority, int *high_priority, int device_id) {
  static std::vector<std::pair<int, int>> cache(num_devices());
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
  auto init = [&]() {
    const int ori_dev = current_device();
    const bool switch_device = (device_id != ori_dev);
    if (switch_device) NVTE_CHECK_CUDA(cudaSetDevice(device_id));
    int min_pri = 0, max_pri = 0;
    // Hold on to the query status and restore the original device BEFORE
    // checking it, so a failed query does not leave the calling thread
    // switched to another device.
    const auto query_status = cudaDeviceGetStreamPriorityRange(&min_pri, &max_pri);
    if (switch_device) NVTE_CHECK_CUDA(cudaSetDevice(ori_dev));
    NVTE_CHECK_CUDA(query_status);
    cache[device_id] = std::make_pair(min_pri, max_pri);
  };
  std::call_once(flags[device_id], init);
  *low_priority = cache[device_id].first;
  *high_priority = cache[device_id].second;
}

126
127
// Whether a device reports CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED.
//
// device_id: device to query; a negative value selects current_device().
// Returns false when built against, or running with, a CUDA runtime older
// than 12.1, or when the driver version is older than 12.1.
// The attribute is queried once per device and cached.
bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
  // NOTE: This needs to be guarded at compile-time and run-time because the
  //       CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
  if (cudart_version() < 12010) {
    return false;
  }
  // Deliberately not std::vector<bool>: its elements are bit-packed, so
  // threads initializing *different* devices (each under its own once_flag)
  // would still write to shared storage. One char per device keeps the
  // entries independent.
  static std::vector<char> cache(num_devices(), 0);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
  auto init = [&]() {
    CUdevice cudev;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
    // Multicast support requires both CUDA12.1 UMD + KMD
    int result = 0;
    // Check if KMD >= 12.1
    int driver_version;
    NVTE_CHECK_CUDA(cudaDriverGetVersion(&driver_version));
    if (driver_version >= 12010) {
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
                                  CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
    }
    cache[device_id] = (result != 0) ? 1 : 0;
  };
  std::call_once(flags[device_id], init);
  return cache[device_id] != 0;
#else
  return false;
#endif
}

yuguo's avatar
yuguo committed
160
#ifndef __HIP_PLATFORM_AMD__
Tim Moon's avatar
Tim Moon committed
161
162
163
164
165
166
167
168
169
170
171
const std::string &include_directory(bool required) {
  static std::string path;

  // Update cached path if needed
  static bool need_to_check_env = true;
  if (path.empty() && required) {
    need_to_check_env = true;
  }
  if (need_to_check_env) {
    // Search for CUDA headers in common paths
    using Path = std::filesystem::path;
172
173
174
175
176
    std::vector<std::pair<std::string, Path>> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""},
                                                              {"CUDA_HOME", ""},
                                                              {"CUDA_DIR", ""},
                                                              {"", string_path_cuda_include},
                                                              {"", "/usr/local/cuda"}};
Tim Moon's avatar
Tim Moon committed
177
178
    for (auto &[env, p] : search_paths) {
      if (p.empty()) {
179
        p = getenv<Path>(env.c_str());
Tim Moon's avatar
Tim Moon committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
      }
      if (!p.empty()) {
        if (file_exists(p / "cuda_runtime.h")) {
          path = p;
          break;
        }
        if (file_exists(p / "include" / "cuda_runtime.h")) {
          path = p / "include";
          break;
        }
      }
    }

    // Throw exception if path is required but not found
    if (path.empty() && required) {
      std::string message;
      message.reserve(2048);
      message += "Could not find cuda_runtime.h in";
      bool is_first = true;
      for (const auto &[env, p] : search_paths) {
        message += is_first ? " " : ", ";
        is_first = false;
        if (!env.empty()) {
          message += env;
          message += "=";
        }
        if (p.empty()) {
          message += "<unset>";
        } else {
          message += p;
        }
      }
212
213
214
215
216
      message +=
          (". "
           "Specify path to CUDA Toolkit headers "
           "with NVTE_CUDA_INCLUDE_DIR "
           "or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
Tim Moon's avatar
Tim Moon committed
217
218
219
220
221
222
223
224
      NVTE_ERROR(message);
    }
    need_to_check_env = false;
  }

  // Return cached path
  return path;
}
yuguo's avatar
yuguo committed
225
#endif // __HIP_PLATFORM_AMD__
Tim Moon's avatar
Tim Moon committed
226

227
228
229
230
231
232
233
234
235
236
// Version number reported by cudaRuntimeGetVersion.
//
// The runtime is queried on the first call only; later calls return the
// value stored in the function-local static.
int cudart_version() {
  static int cached_version = []() -> int {
    int version;
    NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&version));
    return version;
  }();
  return cached_version;
}

237
238
239
240
241
242
// Version number reported by cublasLtGetVersion.
size_t cublas_version() {
  // Ask cuBLASLt only once: repeated queries incur cuBLAS logging overhead,
  // so the first answer is kept in a function-local static.
  static size_t cached_version = cublasLtGetVersion();
  return cached_version;
}

Tim Moon's avatar
Tim Moon committed
243
244
245
}  // namespace cuda

}  // namespace transformer_engine