Unverified Commit 719b29f2 authored by Peng Zhang's avatar Peng Zhang Committed by GitHub
Browse files

feat: enchance green context stream creation robust with backward compatibility (#8136)

parent d0510f08
...@@ -7,17 +7,15 @@ ...@@ -7,17 +7,15 @@
#include "cuda_utils.h" #include "cuda_utils.h"
#include "greenctx_stream.h" #include "greenctx_stream.h"
std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
CUstream streamA, streamB; CUstream streamA, streamB;
CUcontext ctx; CUcontext ctx;
// Stream A
CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[0])); CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[0]));
CUDA_DRV(cuCtxPushCurrent(ctx)); CUDA_DRV(cuCtxPushCurrent(ctx));
CUDA_DRV(cuStreamCreate(&streamA, CU_STREAM_NON_BLOCKING)); CUDA_DRV(cuStreamCreate(&streamA, CU_STREAM_NON_BLOCKING));
CUDA_DRV(cuCtxPopCurrent(nullptr)); CUDA_DRV(cuCtxPopCurrent(nullptr));
// Stream B
CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[1])); CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[1]));
CUDA_DRV(cuCtxPushCurrent(ctx)); CUDA_DRV(cuCtxPushCurrent(ctx));
CUDA_DRV(cuStreamCreate(&streamB, CU_STREAM_NON_BLOCKING)); CUDA_DRV(cuStreamCreate(&streamB, CU_STREAM_NON_BLOCKING));
...@@ -26,18 +24,31 @@ std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { ...@@ -26,18 +24,31 @@ std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
return {(int64_t)streamA, (int64_t)streamB}; return {(int64_t)streamA, (int64_t)streamB};
} }
#if CUDA_VERSION >= 12050 typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int);
std::vector<int64_t> create_greenctx_stream_direct(CUgreenCtx gctx[2]) {
CUstream streamA;
CUstream streamB;
CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) {
CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); static PFN_cuGreenCtxStreamCreate pfn = nullptr;
static std::once_flag pfn_probed_flag;
std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB}; // detect compatibility in runtime
return vec; std::call_once(pfn_probed_flag, []() {
cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), 0, 0, nullptr);
});
if (!pfn) { // fallback if not compatible
return create_greenctx_stream_fallback(gctx);
}
CUstream streamA, streamB;
CUDA_DRV(pfn(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0));
CUDA_DRV(pfn(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0));
return {(int64_t)streamA, (int64_t)streamB};
}
inline void destroy_green_context(int64_t h) {
if (h) CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast<CUgreenCtx>(h)));
} }
#endif
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
...@@ -46,42 +57,38 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i ...@@ -46,42 +57,38 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
CUdevResourceDesc desc[3]; CUdevResourceDesc desc[3];
CUdevResource input; CUdevResource input;
CUdevResource resources[4]; CUdevResource resources[4];
unsigned int nbGroups = 1;
if (smA <= 0 || smB <= 0) { if (smA <= 0 || smB <= 0) {
TORCH_CHECK(false, "SM counts must be positive"); TORCH_CHECK(false, "SM counts must be positive");
} }
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
unsigned int minCount = (unsigned int)(smA + smB);
unsigned int minCountA = (unsigned int)(smA); const unsigned minCount = smA + smB;
const unsigned minCountA = smA;
TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");
unsigned nbGroups = 1;
CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount)); CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount));
CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1));
CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM)); CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM));
nbGroups = 1;
CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA)); CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA));
CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1));
CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1));
CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
int smCountA = resources[0].sm.smCount;
int smCountB = resources[1].sm.smCount;
std::vector<int64_t> stream_handles; const int smCountA = resources[0].sm.smCount;
const int smCountB = resources[1].sm.smCount;
#if CUDA_VERSION >= 12050 std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx);
stream_handles = create_greenctx_stream_direct(gctx);
#else
stream_handles = create_greenctx_stream_fallback(gctx);
#endif
CUDA_DRV(cuGreenCtxDestroy(gctx[2])); CUDA_DRV(cuGreenCtxDestroy(gctx[2]));
std::vector<int64_t> vec = { std::vector<int64_t> vec = {
stream_handles[0], // streamA streams[0], // streamA
stream_handles[1], // streamB streams[1], // streamB
(int64_t)smCountA, (int64_t)smCountA,
(int64_t)smCountB}; (int64_t)smCountB};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment