Unverified commit 528e9b14, authored by Kebe and committed by GitHub
Browse files

[Feature][Core] Support Fabric detection to adapt the MNNVL protocol for the GB series (#33540)


Signed-off-by: Kebe <mail@kebe7jun.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Thomas Vegas <tvegas@nvidia.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
parent d95b4be4
......@@ -115,11 +115,28 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
if (flag) { // support GPUDirect RDMA if possible
prop.allocFlags.gpuDirectRDMACapable = 1;
}
int fab_flag = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&fab_flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
if (fab_flag) { // support fabric handle if possible
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
}
#endif
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
CUresult ret = (CUresult)cuMemCreate(p_memHandle, size, &prop, 0);
if (ret) {
if (fab_flag &&
(ret == CUDA_ERROR_NOT_PERMITTED || ret == CUDA_ERROR_NOT_SUPPORTED)) {
// Fabric allocation may fail without multi-node nvlink,
// fallback to POSIX file descriptor
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
} else {
CUDA_CHECK(ret);
}
}
if (error_code != 0) {
return;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment