Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
719b29f2
Unverified
Commit
719b29f2
authored
Jul 18, 2025
by
Peng Zhang
Committed by
GitHub
Jul 18, 2025
Browse files
feat: enchance green context stream creation robust with backward compatibility (#8136)
parent
d0510f08
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
26 deletions
+33
-26
sgl-kernel/csrc/spatial/greenctx_stream.cu
sgl-kernel/csrc/spatial/greenctx_stream.cu
+33
-26
No files found.
sgl-kernel/csrc/spatial/greenctx_stream.cu
View file @
719b29f2
...
@@ -7,17 +7,15 @@
...
@@ -7,17 +7,15 @@
#include "cuda_utils.h"
#include "cuda_utils.h"
#include "greenctx_stream.h"
#include "greenctx_stream.h"
std
::
vector
<
int64_t
>
create_greenctx_stream_fallback
(
CUgreenCtx
gctx
[
2
])
{
static
std
::
vector
<
int64_t
>
create_greenctx_stream_fallback
(
CUgreenCtx
gctx
[
2
])
{
CUstream
streamA
,
streamB
;
CUstream
streamA
,
streamB
;
CUcontext
ctx
;
CUcontext
ctx
;
// Stream A
CUDA_DRV
(
cuCtxFromGreenCtx
(
&
ctx
,
gctx
[
0
]));
CUDA_DRV
(
cuCtxFromGreenCtx
(
&
ctx
,
gctx
[
0
]));
CUDA_DRV
(
cuCtxPushCurrent
(
ctx
));
CUDA_DRV
(
cuCtxPushCurrent
(
ctx
));
CUDA_DRV
(
cuStreamCreate
(
&
streamA
,
CU_STREAM_NON_BLOCKING
));
CUDA_DRV
(
cuStreamCreate
(
&
streamA
,
CU_STREAM_NON_BLOCKING
));
CUDA_DRV
(
cuCtxPopCurrent
(
nullptr
));
CUDA_DRV
(
cuCtxPopCurrent
(
nullptr
));
// Stream B
CUDA_DRV
(
cuCtxFromGreenCtx
(
&
ctx
,
gctx
[
1
]));
CUDA_DRV
(
cuCtxFromGreenCtx
(
&
ctx
,
gctx
[
1
]));
CUDA_DRV
(
cuCtxPushCurrent
(
ctx
));
CUDA_DRV
(
cuCtxPushCurrent
(
ctx
));
CUDA_DRV
(
cuStreamCreate
(
&
streamB
,
CU_STREAM_NON_BLOCKING
));
CUDA_DRV
(
cuStreamCreate
(
&
streamB
,
CU_STREAM_NON_BLOCKING
));
...
@@ -26,18 +24,31 @@ std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
...
@@ -26,18 +24,31 @@ std::vector<int64_t> create_greenctx_stream_fallback(CUgreenCtx gctx[2]) {
return
{(
int64_t
)
streamA
,
(
int64_t
)
streamB
};
return
{(
int64_t
)
streamA
,
(
int64_t
)
streamB
};
}
}
#if CUDA_VERSION >= 12050
typedef
CUresult
(
CUDAAPI
*
PFN_cuGreenCtxStreamCreate
)(
CUstream
*
,
CUgreenCtx
,
unsigned
int
,
int
);
std
::
vector
<
int64_t
>
create_greenctx_stream_direct
(
CUgreenCtx
gctx
[
2
])
{
CUstream
streamA
;
CUstream
streamB
;
CUDA_DRV
(
cuGreenCtxStreamCreate
(
&
streamA
,
gctx
[
0
],
CU_STREAM_NON_BLOCKING
,
0
));
static
std
::
vector
<
int64_t
>
create_greenctx_stream_direct_dynamic
(
CUgreenCtx
gctx
[
2
])
{
CUDA_DRV
(
cuGreenCtxStreamCreate
(
&
streamB
,
gctx
[
1
],
CU_STREAM_NON_BLOCKING
,
0
));
static
PFN_cuGreenCtxStreamCreate
pfn
=
nullptr
;
static
std
::
once_flag
pfn_probed_flag
;
std
::
vector
<
int64_t
>
vec
=
{(
int64_t
)
streamA
,
(
int64_t
)
streamB
};
// detect compatibility in runtime
return
vec
;
std
::
call_once
(
pfn_probed_flag
,
[]()
{
cuGetProcAddress
(
"cuGreenCtxStreamCreate"
,
reinterpret_cast
<
void
**>
(
&
pfn
),
0
,
0
,
nullptr
);
});
if
(
!
pfn
)
{
// fallback if not compatible
return
create_greenctx_stream_fallback
(
gctx
);
}
CUstream
streamA
,
streamB
;
CUDA_DRV
(
pfn
(
&
streamA
,
gctx
[
0
],
CU_STREAM_NON_BLOCKING
,
0
));
CUDA_DRV
(
pfn
(
&
streamB
,
gctx
[
1
],
CU_STREAM_NON_BLOCKING
,
0
));
return
{(
int64_t
)
streamA
,
(
int64_t
)
streamB
};
}
inline
void
destroy_green_context
(
int64_t
h
)
{
if
(
h
)
CUDA_DRV
(
cuGreenCtxDestroy
(
reinterpret_cast
<
CUgreenCtx
>
(
h
)));
}
}
#endif
std
::
vector
<
int64_t
>
create_greenctx_stream_by_value
(
int64_t
smA
,
int64_t
smB
,
int64_t
device
)
{
std
::
vector
<
int64_t
>
create_greenctx_stream_by_value
(
int64_t
smA
,
int64_t
smB
,
int64_t
device
)
{
TORCH_CHECK
(
CUDA_VERSION
>=
12040
,
"Green Contexts feature requires CUDA Toolkit 12.4 or newer."
);
TORCH_CHECK
(
CUDA_VERSION
>=
12040
,
"Green Contexts feature requires CUDA Toolkit 12.4 or newer."
);
...
@@ -46,42 +57,38 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
...
@@ -46,42 +57,38 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
CUdevResourceDesc
desc
[
3
];
CUdevResourceDesc
desc
[
3
];
CUdevResource
input
;
CUdevResource
input
;
CUdevResource
resources
[
4
];
CUdevResource
resources
[
4
];
unsigned
int
nbGroups
=
1
;
if
(
smA
<=
0
||
smB
<=
0
)
{
if
(
smA
<=
0
||
smB
<=
0
)
{
TORCH_CHECK
(
false
,
"SM counts must be positive"
);
TORCH_CHECK
(
false
,
"SM counts must be positive"
);
}
}
CUDA_DRV
(
cuDeviceGetDevResource
((
CUdevice
)
device
,
&
input
,
CU_DEV_RESOURCE_TYPE_SM
));
CUDA_DRV
(
cuDeviceGetDevResource
((
CUdevice
)
device
,
&
input
,
CU_DEV_RESOURCE_TYPE_SM
));
unsigned
int
minCount
=
(
unsigned
int
)(
smA
+
smB
);
unsigned
int
minCountA
=
(
unsigned
int
)(
smA
);
const
unsigned
minCount
=
smA
+
smB
;
const
unsigned
minCountA
=
smA
;
TORCH_CHECK
(
minCount
<=
input
.
sm
.
smCount
,
"Not enough SMs available for the requested configuration"
);
TORCH_CHECK
(
minCount
<=
input
.
sm
.
smCount
,
"Not enough SMs available for the requested configuration"
);
unsigned
nbGroups
=
1
;
CUDA_DRV
(
cuDevSmResourceSplitByCount
(
&
resources
[
2
],
&
nbGroups
,
&
input
,
&
resources
[
3
],
0
,
minCount
));
CUDA_DRV
(
cuDevSmResourceSplitByCount
(
&
resources
[
2
],
&
nbGroups
,
&
input
,
&
resources
[
3
],
0
,
minCount
));
CUDA_DRV
(
cuDevResourceGenerateDesc
(
&
desc
[
2
],
&
resources
[
2
],
1
));
CUDA_DRV
(
cuDevResourceGenerateDesc
(
&
desc
[
2
],
&
resources
[
2
],
1
));
CUDA_DRV
(
cuGreenCtxCreate
(
&
gctx
[
2
],
desc
[
2
],
(
CUdevice
)
device
,
CU_GREEN_CTX_DEFAULT_STREAM
));
CUDA_DRV
(
cuGreenCtxCreate
(
&
gctx
[
2
],
desc
[
2
],
(
CUdevice
)
device
,
CU_GREEN_CTX_DEFAULT_STREAM
));
CUDA_DRV
(
cuGreenCtxGetDevResource
(
gctx
[
2
],
&
input
,
CU_DEV_RESOURCE_TYPE_SM
));
CUDA_DRV
(
cuGreenCtxGetDevResource
(
gctx
[
2
],
&
input
,
CU_DEV_RESOURCE_TYPE_SM
));
nbGroups
=
1
;
CUDA_DRV
(
cuDevSmResourceSplitByCount
(
&
resources
[
0
],
&
nbGroups
,
&
input
,
&
resources
[
1
],
0
,
minCountA
));
CUDA_DRV
(
cuDevSmResourceSplitByCount
(
&
resources
[
0
],
&
nbGroups
,
&
input
,
&
resources
[
1
],
0
,
minCountA
));
CUDA_DRV
(
cuDevResourceGenerateDesc
(
&
desc
[
0
],
&
resources
[
0
],
1
));
CUDA_DRV
(
cuDevResourceGenerateDesc
(
&
desc
[
0
],
&
resources
[
0
],
1
));
CUDA_DRV
(
cuGreenCtxCreate
(
&
gctx
[
0
],
desc
[
0
],
(
CUdevice
)
device
,
CU_GREEN_CTX_DEFAULT_STREAM
));
CUDA_DRV
(
cuGreenCtxCreate
(
&
gctx
[
0
],
desc
[
0
],
(
CUdevice
)
device
,
CU_GREEN_CTX_DEFAULT_STREAM
));
CUDA_DRV
(
cuDevResourceGenerateDesc
(
&
desc
[
1
],
&
resources
[
1
],
1
));
CUDA_DRV
(
cuDevResourceGenerateDesc
(
&
desc
[
1
],
&
resources
[
1
],
1
));
CUDA_DRV
(
cuGreenCtxCreate
(
&
gctx
[
1
],
desc
[
1
],
(
CUdevice
)
device
,
CU_GREEN_CTX_DEFAULT_STREAM
));
CUDA_DRV
(
cuGreenCtxCreate
(
&
gctx
[
1
],
desc
[
1
],
(
CUdevice
)
device
,
CU_GREEN_CTX_DEFAULT_STREAM
));
int
smCountA
=
resources
[
0
].
sm
.
smCount
;
int
smCountB
=
resources
[
1
].
sm
.
smCount
;
std
::
vector
<
int64_t
>
stream_handles
;
const
int
smCountA
=
resources
[
0
].
sm
.
smCount
;
const
int
smCountB
=
resources
[
1
].
sm
.
smCount
;
#if CUDA_VERSION >= 12050
std
::
vector
<
int64_t
>
streams
=
create_greenctx_stream_direct_dynamic
(
gctx
);
stream_handles
=
create_greenctx_stream_direct
(
gctx
);
#else
stream_handles
=
create_greenctx_stream_fallback
(
gctx
);
#endif
CUDA_DRV
(
cuGreenCtxDestroy
(
gctx
[
2
]));
CUDA_DRV
(
cuGreenCtxDestroy
(
gctx
[
2
]));
std
::
vector
<
int64_t
>
vec
=
{
std
::
vector
<
int64_t
>
vec
=
{
stream
_handle
s
[
0
],
// streamA
streams
[
0
],
// streamA
stream
_handle
s
[
1
],
// streamB
streams
[
1
],
// streamB
(
int64_t
)
smCountA
,
(
int64_t
)
smCountA
,
(
int64_t
)
smCountB
};
(
int64_t
)
smCountB
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment