"maint/host_checks/08_device_id_mismatch.py" did not exist on "b8240b7ae9387ba7143e6243b59069c3a04a12e9"
Unverified Commit 445f9dca authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Runtime check CUDA driver version to avoid unresolved green context symbols (#9021)

parent 3a9afe2a
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#define CUDA_RT(call) \
do { \
cudaError_t _status = (call); \
......
// Documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
#include <torch/all.h>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include "cuda_utils.h"
#include "greenctx_stream.h"
#if CUDA_VERSION >= 12040
static int CUDA_DRIVER_VERSION;
using PFN_cuGreenCtxStreamCreate = CUresult(CUDAAPI*)(CUstream*, CUgreenCtx, unsigned int, int);
// Looks up the driver entry point named "cuGreenCtxStreamCreate" at runtime
// instead of linking against it directly.
//
// The lookup is requested for the version held in the file-level
// CUDA_DRIVER_VERSION variable.
// NOTE(review): that variable must already have been filled in (see the
// cuDriverGetVersion call in create_greenctx_stream_by_value) -- confirm that
// no caller reaches this function earlier.
//
// Returns the pointer written by cuGetProcAddress; it is still nullptr when the
// lookup stored nothing, which the caller treats as "fall back".
// CUDA_DRV is defined outside this view; presumably it reports a failing
// CUresult -- TODO confirm.
auto probe_cuGreenCtxStreamCreate() -> PFN_cuGreenCtxStreamCreate {
  static PFN_cuGreenCtxStreamCreate entry_point = nullptr;
  // cuGetProcAddress writes through a generic void** slot.
  void** const out_slot = reinterpret_cast<void**>(&entry_point);
  CUDA_DRV(cuGetProcAddress("cuGreenCtxStreamCreate", out_slot, CUDA_DRIVER_VERSION, 0, nullptr));
  return entry_point;
}
static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
CUstream streamA, streamB;
......@@ -26,18 +33,15 @@ static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2])
return {(int64_t)streamA, (int64_t)streamB};
}
typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int);
// Destroys the given green context through cuGreenCtxDestroy.
// A null handle is ignored, so callers may pass a context that was never
// created.
inline void destroy_green_context(CUgreenCtx gctx) {
  if (gctx != nullptr) {
    CUDA_DRV(cuGreenCtxDestroy(gctx));
  }
}
static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) {
static PFN_cuGreenCtxStreamCreate pfn = nullptr;
static std::once_flag pfn_probed_flag;
// detect compatibility in runtime
std::call_once(pfn_probed_flag, []() {
cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), 0, 0, nullptr);
});
if (!pfn) { // fallback if not compatible
// This symbol is introduced in CUDA 12.5
const static auto pfn = probe_cuGreenCtxStreamCreate();
if (!pfn) {
return create_greenctx_stream_fallback(gctx);
}
......@@ -48,12 +52,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
return {(int64_t)streamA, (int64_t)streamB};
}
inline void destroy_green_context(int64_t h) {
if (h) CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast<CUgreenCtx>(h)));
}
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION));
if (CUDA_DRIVER_VERSION < 12040) {
TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
}
CUgreenCtx gctx[3];
CUdevResourceDesc desc[3];
......@@ -65,8 +69,8 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
const unsigned minCount = smA + smB;
const unsigned minCountA = smA;
const unsigned minCount = static_cast<unsigned>(smA + smB);
const unsigned minCountA = static_cast<unsigned>(smA);
TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
unsigned nbGroups = 1;
......@@ -86,7 +90,7 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx);
CUDA_DRV(cuGreenCtxDestroy(gctx[2]));
destroy_green_context(gctx[2]);
std::vector<int64_t> vec = {
streams[0], // streamA
......@@ -96,18 +100,3 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
return vec;
}
#else
// Stub compiled when the CUDA_VERSION guard above is not satisfied.
// It always fails the TORCH_CHECK with a message that includes the CUDA_VERSION
// value this file was built against; the arguments are intentionally unused.
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
  // TORCH_CHECK joins its message arguments, so the text is the same as a
  // single pre-concatenated string.
  TORCH_CHECK(
      false,
      "Green Contexts feature requires CUDA Toolkit 12.4 or newer. Current CUDA version: ",
      std::to_string(CUDA_VERSION));
  // Not expected to be reached; the empty vector only satisfies the return type.
  return {};
}
#endif
......@@ -14,6 +14,7 @@
# ==============================================================================
import functools
import subprocess
from typing import Dict, Tuple
import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment