Unverified Commit 445f9dca authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Runtime check CUDA driver version to avoid unresolved green context symbols (#9021)

parent 3a9afe2a
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream>
#define CUDA_RT(call) \ #define CUDA_RT(call) \
do { \ do { \
cudaError_t _status = (call); \ cudaError_t _status = (call); \
......
// Documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
#include <torch/all.h> #include <torch/all.h>
#include <cstdlib> #include <cstdlib>
#include <iomanip>
#include <iostream>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "greenctx_stream.h" #include "greenctx_stream.h"
#if CUDA_VERSION >= 12040 static int CUDA_DRIVER_VERSION;
using PFN_cuGreenCtxStreamCreate = CUresult(CUDAAPI*)(CUstream*, CUgreenCtx, unsigned int, int);
// Looks up the driver entry point named "cuGreenCtxStreamCreate" at run time
// through cuGetProcAddress, requesting it for the version stored in the
// file-level CUDA_DRIVER_VERSION, and returns the resulting function pointer.
// NOTE(review): CUDA_DRIVER_VERSION is only assigned elsewhere in this file
// (via cuDriverGetVersion) -- presumably that must happen before this is
// called; confirm the call order.
auto probe_cuGreenCtxStreamCreate() -> PFN_cuGreenCtxStreamCreate {
// Function-local static, initialised to nullptr; the lookup below writes into it.
static PFN_cuGreenCtxStreamCreate pfn = nullptr;
// Fourth argument (flags) is 0 and no status out-parameter is supplied (nullptr);
// the CUresult of the lookup is handed to the CUDA_DRV macro.
CUDA_DRV(cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), CUDA_DRIVER_VERSION, 0, nullptr));
return pfn;
}
static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
CUstream streamA, streamB; CUstream streamA, streamB;
...@@ -26,18 +33,15 @@ static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) ...@@ -26,18 +33,15 @@ static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2])
return {(int64_t)streamA, (int64_t)streamB}; return {(int64_t)streamA, (int64_t)streamB};
} }
// PFN_cuGreenCtxStreamCreate: pointer-to-function type taking
// (CUstream*, CUgreenCtx, unsigned int, int) and returning CUresult.
// destroy_green_context: destroys `gctx` with cuGreenCtxDestroy (result passed
// to CUDA_DRV); a null handle is ignored.
typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int); inline void destroy_green_context(CUgreenCtx gctx) {
// Nothing to release for a null handle.
if (!gctx) return;
CUDA_DRV(cuGreenCtxDestroy(gctx));
}
static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) { static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) {
static PFN_cuGreenCtxStreamCreate pfn = nullptr; // This symbol is introduced in CUDA 12.5
static std::once_flag pfn_probed_flag; const static auto pfn = probe_cuGreenCtxStreamCreate();
if (!pfn) {
// detect compatibility in runtime
std::call_once(pfn_probed_flag, []() {
cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), 0, 0, nullptr);
});
if (!pfn) { // fallback if not compatible
return create_greenctx_stream_fallback(gctx); return create_greenctx_stream_fallback(gctx);
} }
...@@ -48,12 +52,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct ...@@ -48,12 +52,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
return {(int64_t)streamA, (int64_t)streamB}; return {(int64_t)streamA, (int64_t)streamB};
} }
// Destroys the green context whose handle is carried in `h` as an integer.
// A zero handle is ignored; any other value is cast back to CUgreenCtx and
// passed to cuGreenCtxDestroy, with the result handed to CUDA_DRV.
inline void destroy_green_context(int64_t h) {
  if (h == 0) {
    return;
  }
  CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast<CUgreenCtx>(h)));
}
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION));
if (CUDA_DRIVER_VERSION < 12040) {
TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
}
CUgreenCtx gctx[3]; CUgreenCtx gctx[3];
CUdevResourceDesc desc[3]; CUdevResourceDesc desc[3];
...@@ -65,8 +69,8 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i ...@@ -65,8 +69,8 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
const unsigned minCount = smA + smB; const unsigned minCount = static_cast<unsigned>(smA + smB);
const unsigned minCountA = smA; const unsigned minCountA = static_cast<unsigned>(smA);
TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
unsigned nbGroups = 1; unsigned nbGroups = 1;
...@@ -86,7 +90,7 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i ...@@ -86,7 +90,7 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx); std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx);
CUDA_DRV(cuGreenCtxDestroy(gctx[2])); destroy_green_context(gctx[2]);
std::vector<int64_t> vec = { std::vector<int64_t> vec = {
streams[0], // streamA streams[0], // streamA
...@@ -96,18 +100,3 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i ...@@ -96,18 +100,3 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
return vec; return vec;
} }
#else
// Stub compiled in the #else branch of the CUDA_VERSION >= 12040 guard.
// It always fails a TORCH_CHECK whose message reports the compile-time
// CUDA_VERSION; the parameters smA, smB and device are not used.
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
  const std::string current_version = std::to_string(CUDA_VERSION);
  TORCH_CHECK(
      false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer. Current CUDA version: " + current_version);
  // Not expected to be reached because the check above always fails; the
  // empty vector only satisfies the declared return type.
  return {};
}
#endif
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
import functools import functools
import subprocess
from typing import Dict, Tuple from typing import Dict, Tuple
import torch import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment