cuda_runtime.cpp 6.4 KB
Newer Older
Tim Moon's avatar
Tim Moon committed
1
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
#include "../util/cuda_runtime.h"

Tim Moon's avatar
Tim Moon committed
9
10
11
12
13
14
#include <filesystem>
#include <mutex>

#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/system.h"
15
#include "common/util/cuda_runtime.h"
Tim Moon's avatar
Tim Moon committed
16
17
18
19
20
21
22
23
24
25
26
27
28

namespace transformer_engine {

namespace cuda {

namespace {

// String with build-time CUDA include path
#include "string_path_cuda_include.h"

}  // namespace

int num_devices() {
29
  auto query_num_devices = []() -> int {
Tim Moon's avatar
Tim Moon committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
    int count;
    NVTE_CHECK_CUDA(cudaGetDeviceCount(&count));
    return count;
  };
  static int num_devices_ = query_num_devices();
  return num_devices_;
}

int current_device() {
  // Return 0 if CUDA context is not initialized
  CUcontext context;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
  if (context == nullptr) {
    return 0;
  }

  // Query device from CUDA runtime
  int device_id;
  NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
  return device_id;
}

int sm_arch(int device_id) {
  static std::vector<int> cache(num_devices(), -1);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
59
  auto init = [&]() {
Tim Moon's avatar
Tim Moon committed
60
61
    cudaDeviceProp prop;
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
62
    cache[device_id] = 10 * prop.major + prop.minor;
Tim Moon's avatar
Tim Moon committed
63
64
65
66
67
68
69
70
71
72
73
74
  };
  std::call_once(flags[device_id], init);
  return cache[device_id];
}

int sm_count(int device_id) {
  static std::vector<int> cache(num_devices(), -1);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
75
  auto init = [&]() {
Tim Moon's avatar
Tim Moon committed
76
77
78
79
80
81
82
83
    cudaDeviceProp prop;
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
    cache[device_id] = prop.multiProcessorCount;
  };
  std::call_once(flags[device_id], init);
  return cache[device_id];
}

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
void stream_priority_range(int *low_priority, int *high_priority, int device_id) {
  static std::vector<std::pair<int, int>> cache(num_devices());
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
  auto init = [&]() {
    int ori_dev = current_device();
    if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(device_id));
    int min_pri, max_pri;
    NVTE_CHECK_CUDA(cudaDeviceGetStreamPriorityRange(&min_pri, &max_pri));
    if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(ori_dev));
    cache[device_id] = std::make_pair(min_pri, max_pri);
  };
  std::call_once(flags[device_id], init);
  *low_priority = cache[device_id].first;
  *high_priority = cache[device_id].second;
}

// Whether a CUDA device supports multicast objects
// (CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED).
//
// device_id: device to query; a negative value selects current_device().
// Returns false when built against, or running with, a CUDA runtime older than
// 12.1, or when the installed driver is older than 12.1. Throws (via NVTE_CHECK)
// if the resolved ID is outside [0, num_devices()). The result is computed once
// per device and cached.
bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
  // NOTE: This needs to be guarded at compile-time and run-time because the
  //       CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
  if (cudart_version() < 12010) {
    return false;
  }
  // Stored as int rather than bool: std::vector<bool> packs elements into shared
  // words, so two threads initializing different devices concurrently (each guarded
  // only by its own once_flag) would race on the same memory location.
  static std::vector<int> cache(num_devices(), 0);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {
    device_id = current_device();
  }
  NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
  auto init = [&]() {
    CUdevice cudev;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
    // Multicast support requires both CUDA12.1 UMD + KMD
    int result = 0;
    // Check if KMD >= 12.1
    int driver_version;
    NVTE_CHECK_CUDA(cudaDriverGetVersion(&driver_version));
    if (driver_version >= 12010) {
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
                                  CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
    }
    cache[device_id] = (result != 0) ? 1 : 0;
  };
  std::call_once(flags[device_id], init);
  return cache[device_id] != 0;
#else
  (void)device_id;  // unused when multicast cannot be supported at compile time
  return false;
#endif
}

Tim Moon's avatar
Tim Moon committed
138
139
140
141
142
143
144
145
146
147
148
const std::string &include_directory(bool required) {
  static std::string path;

  // Update cached path if needed
  static bool need_to_check_env = true;
  if (path.empty() && required) {
    need_to_check_env = true;
  }
  if (need_to_check_env) {
    // Search for CUDA headers in common paths
    using Path = std::filesystem::path;
149
150
151
152
153
    std::vector<std::pair<std::string, Path>> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""},
                                                              {"CUDA_HOME", ""},
                                                              {"CUDA_DIR", ""},
                                                              {"", string_path_cuda_include},
                                                              {"", "/usr/local/cuda"}};
Tim Moon's avatar
Tim Moon committed
154
155
    for (auto &[env, p] : search_paths) {
      if (p.empty()) {
156
        p = getenv<Path>(env.c_str());
Tim Moon's avatar
Tim Moon committed
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
      }
      if (!p.empty()) {
        if (file_exists(p / "cuda_runtime.h")) {
          path = p;
          break;
        }
        if (file_exists(p / "include" / "cuda_runtime.h")) {
          path = p / "include";
          break;
        }
      }
    }

    // Throw exception if path is required but not found
    if (path.empty() && required) {
      std::string message;
      message.reserve(2048);
      message += "Could not find cuda_runtime.h in";
      bool is_first = true;
      for (const auto &[env, p] : search_paths) {
        message += is_first ? " " : ", ";
        is_first = false;
        if (!env.empty()) {
          message += env;
          message += "=";
        }
        if (p.empty()) {
          message += "<unset>";
        } else {
          message += p;
        }
      }
189
190
191
192
193
      message +=
          (". "
           "Specify path to CUDA Toolkit headers "
           "with NVTE_CUDA_INCLUDE_DIR "
           "or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
Tim Moon's avatar
Tim Moon committed
194
195
196
197
198
199
200
201
202
      NVTE_ERROR(message);
    }
    need_to_check_env = false;
  }

  // Return cached path
  return path;
}

203
204
205
206
207
208
209
210
211
212
int cudart_version() {
  auto get_version = []() -> int {
    int version;
    NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&version));
    return version;
  };
  static int version = get_version();
  return version;
}

Tim Moon's avatar
Tim Moon committed
213
214
215
}  // namespace cuda

}  // namespace transformer_engine