Unverified Commit 6c942ffd authored by Nicolas Castet's avatar Nicolas Castet Committed by GitHub
Browse files

Check CUDA driver KMD version for multicast symbol support (#1710)



Fixes #1692
Signed-off-by: default avatarNicolas Castet <26874160+nvcastet@users.noreply.github.com>
parent 6a969f0e
...@@ -114,9 +114,15 @@ bool supports_multicast(int device_id) { ...@@ -114,9 +114,15 @@ bool supports_multicast(int device_id) {
auto init = [&]() { auto init = [&]() {
CUdevice cudev; CUdevice cudev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
int result; // Multicast support requires both CUDA12.1 UMD + KMD
int result = 0;
// Check if KMD >= 12.1
int driver_version;
NVTE_CHECK_CUDA(cudaDriverGetVersion(&driver_version));
if (driver_version >= 12010) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
}
cache[device_id] = static_cast<bool>(result); cache[device_id] = static_cast<bool>(result);
}; };
std::call_once(flags[device_id], init); std::call_once(flags[device_id], init);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment