"cmd/vscode:/vscode.git/clone" did not exist on "abfc4893f9d2abe9908fb0e407e5f67de1d0fce6"
Unverified commit 2dd6b146, authored by Tim Moon, committed by GitHub
Browse files

Set CUDA context before loading NVRTC kernels (#734)


Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 16a469df
...@@ -76,6 +76,10 @@ Kernel::~Kernel() { ...@@ -76,6 +76,10 @@ Kernel::~Kernel() {
!= CUDA_SUCCESS) { != CUDA_SUCCESS) {
continue; continue;
} }
if (cuda_driver::call("cuCtxSetCurrent", context)
!= CUDA_SUCCESS) {
continue;
}
cuda_driver::call("cuModuleUnload", modules_[device_id]); cuda_driver::call("cuModuleUnload", modules_[device_id]);
cuda_driver::call("cuDevicePrimaryCtxRelease", device); cuda_driver::call("cuDevicePrimaryCtxRelease", device);
} }
...@@ -109,6 +113,7 @@ CUfunction Kernel::get_function(int device_id) { ...@@ -109,6 +113,7 @@ CUfunction Kernel::get_function(int device_id) {
CUcontext context; CUcontext context;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, device_id); NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, device_id);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
// Load function into driver context // Load function into driver context
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment