greenctx_stream.cu 3.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
#include <torch/all.h>

#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <string>
#include <vector>

#include "cuda_utils.h"
#include "greenctx_stream.h"

10
11
#if CUDA_VERSION >= 12040

12
// Creates one CU_STREAM_NON_BLOCKING stream that belongs to `green`.
// It derives the green context's CUcontext, makes it current, creates the stream with
// cuStreamCreate, and pops the context again.
// NOTE(review): if CUDA_DRV throws instead of aborting, a failure after cuCtxPushCurrent
// leaves the context pushed -- confirm CUDA_DRV's behavior in cuda_utils.h.
static CUstream make_stream_on_green_ctx(CUgreenCtx green) {
  CUcontext plain_ctx;
  CUDA_DRV(cuCtxFromGreenCtx(&plain_ctx, green));
  CUDA_DRV(cuCtxPushCurrent(plain_ctx));

  CUstream stream;
  CUDA_DRV(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));

  CUDA_DRV(cuCtxPopCurrent(nullptr));
  return stream;
}

// Stream creation path used when cuGreenCtxStreamCreate is not available.
// Creates one stream in gctx[0] and one in gctx[1], in that order.
// Returns {streamA, streamB}: the CUstream handles reinterpreted as int64_t.
static std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
  std::vector<int64_t> handles;
  handles.reserve(2);
  for (int i = 0; i < 2; ++i) {
    handles.push_back(reinterpret_cast<int64_t>(make_stream_on_green_ctx(gctx[i])));
  }
  return handles;
}

typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int);
30

31
32
33
static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) {
  static PFN_cuGreenCtxStreamCreate pfn = nullptr;
  static std::once_flag pfn_probed_flag;
34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
  // detect compatibility in runtime
  std::call_once(pfn_probed_flag, []() {
    cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast<void**>(&pfn), 0, 0, nullptr);
  });

  if (!pfn) {  // fallback if not compatible
    return create_greenctx_stream_fallback(gctx);
  }

  CUstream streamA, streamB;
  CUDA_DRV(pfn(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0));
  CUDA_DRV(pfn(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0));

  return {(int64_t)streamA, (int64_t)streamB};
}

// Destroys the green context whose CUgreenCtx handle was stored in an int64_t.
// A zero handle means "nothing to release" and is ignored.
// Driver errors are reported through CUDA_DRV.
// NOTE(review): not called anywhere in the visible part of this file.
inline void destroy_green_context(int64_t h) {
  if (h == 0) {
    return;
  }
  CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast<CUgreenCtx>(h)));
}

// Partitions the SMs of `device` into two green contexts and creates one stream in each.
//
// The split happens in two stages:
//   1. Carve one group of at least smA + smB SMs out of the device's SM resource.
//      The rest of the device is left unused. The group is wrapped in a temporary
//      green context (gctx[2]) so its SM resource can be queried back.
//   2. Split that group into partition A (at least smA SMs) and partition B
//      (whatever remains of the group). Each partition gets its own green context
//      (gctx[0], gctx[1]).
// The returned SM counts are the ones the driver actually assigned. They may be larger
// than the requested minimums.
//
// Params:
//   smA, smB - requested minimum SM counts for partitions A and B; must be positive.
//   device   - CUDA device ordinal.
// Returns:
//   {streamA, streamB, smCountA, smCountB}. The stream handles are CUstream values
//   reinterpreted as int64_t. gctx[0] and gctx[1] are intentionally not destroyed here,
//   because the returned streams live in them.
// Raises:
//   TORCH_CHECK error if the counts are non-positive or the device cannot provide them.
//   Driver call failures are reported through CUDA_DRV.
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
  TORCH_CHECK(smA > 0 && smB > 0, "SM counts must be positive");

  // Translate the ordinal into a driver device handle; this also validates the ordinal.
  CUdevice dev;
  CUDA_DRV(cuDeviceGet(&dev, static_cast<int>(device)));

  CUgreenCtx gctx[3];
  CUdevResourceDesc desc[3];
  CUdevResource input;
  // Zero-initialized so the "remaining" outputs read as 0 SMs if the driver leaves them untouched.
  CUdevResource resources[4] = {};

  CUDA_DRV(cuDeviceGetDevResource(dev, &input, CU_DEV_RESOURCE_TYPE_SM));

  // Check capacity in 64-bit without forming smA + smB first.
  // This way oversized requests can neither overflow nor wrap when narrowed to unsigned below.
  const int64_t available = static_cast<int64_t>(input.sm.smCount);
  TORCH_CHECK(smA <= available && smB <= available - smA,
              "Not enough SMs available for the requested configuration");
  const unsigned minCount = static_cast<unsigned>(smA + smB);
  const unsigned minCountA = static_cast<unsigned>(smA);

  // Stage 1: one group of >= minCount SMs in resources[2].
  // resources[3] receives the unused rest of the device.
  unsigned nbGroups = 1;
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount));
  TORCH_CHECK(nbGroups == 1, "Not enough SMs available for the requested configuration");
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], dev, CU_GREEN_CTX_DEFAULT_STREAM));
  CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM));

  // Stage 2: >= minCountA SMs for A in resources[0]; the remainder of the group goes to B.
  nbGroups = 1;
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA));
  if (nbGroups != 1 || resources[1].sm.smCount == 0) {
    // Once the driver assigns A its SMs, nothing may be left for B.
    // Release the temporary context before reporting the error.
    CUDA_DRV(cuGreenCtxDestroy(gctx[2]));
    TORCH_CHECK(false, "Not enough SMs available for the requested configuration");
  }
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], dev, CU_GREEN_CTX_DEFAULT_STREAM));
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], dev, CU_GREEN_CTX_DEFAULT_STREAM));

  const int smCountA = resources[0].sm.smCount;
  const int smCountB = resources[1].sm.smCount;

  std::vector<int64_t> streams = create_greenctx_stream_direct_dynamic(gctx);

  // gctx[2] was only needed to obtain the stage-1 SM resource split above.
  CUDA_DRV(cuGreenCtxDestroy(gctx[2]));

  return {
      streams[0],  // streamA
      streams[1],  // streamB
      (int64_t)smCountA,
      (int64_t)smCountB};
}

#else

// Stub compiled when the CUDA toolkit is older than 12.4 (CUDA_VERSION < 12040).
// Green contexts are unavailable there, so this always fails via TORCH_CHECK and
// reports the toolkit version it was built with. All arguments are ignored.
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
  (void)smA;
  (void)smB;
  (void)device;

  const std::string message =
      "Green Contexts feature requires CUDA Toolkit 12.4 or newer. Current CUDA version: " +
      std::to_string(CUDA_VERSION);
  TORCH_CHECK(false, message);

  // Unreachable: TORCH_CHECK(false, ...) does not return. Kept only for the return type.
  return {};
}

#endif