zhaoyu6 / sglang, commit 1ebec1a8 (unverified)

[Feature] CUDA Green Context Support (#7649)

Authored Jul 15, 2025 by ykcombat; committed by GitHub on Jul 15, 2025.
Parent: d4d0c7c3

Showing 9 changed files with 190 additions and 0 deletions.
Files changed:

  sgl-kernel/CMakeLists.txt                          +1   -0
  sgl-kernel/csrc/common_extension.cc                +6   -0
  sgl-kernel/csrc/spatial/cuda_utils.h               +44  -0
  sgl-kernel/csrc/spatial/greenctx_stream.cu         +58  -0
  sgl-kernel/csrc/spatial/greenctx_stream.h          +2   -0
  sgl-kernel/include/sgl_kernel_ops.h                +5   -0
  sgl-kernel/python/sgl_kernel/__init__.py           +1   -0
  sgl-kernel/python/sgl_kernel/spatial.py            +48  -0
  sgl-kernel/tests/spatial/test_greenctx_stream.py   +25  -0
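For orientation, here is a minimal sketch of how the pieces added by this commit fit together from Python. The SM counts below are illustrative, and it assumes an sgl-kernel build that includes this commit:

import torch
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available

total = get_sm_available(0)  # number of SMs on device 0
# Ask for half of the SMs for stream A and a quarter for stream B;
# the driver may round the granted counts to its allocation granularity.
stream_a, stream_b = create_greenctx_stream_by_value(total // 2, total // 4, 0)
with torch.cuda.stream(stream_a):
    ...  # kernels launched here run on stream A's SM partition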
sgl-kernel/CMakeLists.txt
@@ -246,6 +246,7 @@ set(SOURCES
     "csrc/moe/ep_moe_silu_and_mul_kernel.cu"
     "csrc/speculative/eagle_utils.cu"
     "csrc/speculative/packbit.cu"
+    "csrc/spatial/greenctx_stream.cu"
     "csrc/speculative/speculative_sampling.cu"
     "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
     "csrc/kvcacheio/transfer.cu"
sgl-kernel/csrc/common_extension.cc
@@ -401,6 +401,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
       "Tensor _ascales, Tensor! _out_feats) -> ()");
   m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
+
+  /*
+   * From csrc/spatial
+   */
+  m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
+  m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
 }

 REGISTER_EXTENSION(common_ops)
sgl-kernel/csrc/spatial/cuda_utils.h (new file, mode 100644)
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>  // for TORCH_CHECK and c10::str used by the macros below

#include <iostream>  // for std::cerr used by the macros below
#define CUDA_RT(call) \
do { \
cudaError_t _status = (call); \
if (_status != cudaSuccess) { \
std::cerr << "ERROR: CUDA RT call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__ \
<< " failed with " << cudaGetErrorString(_status) << std::endl; \
TORCH_CHECK( \
false, \
c10::str( \
"ERROR: CUDA RT call \"", \
#call, \
"\" in line ", \
__LINE__, \
" of file ", \
__FILE__, \
" failed with ", \
cudaGetErrorString(_status))); \
} \
} while (0)
#define CUDA_DRV(call) \
do { \
CUresult _status = (call); \
if (_status != CUDA_SUCCESS) { \
const char* err_str; \
cuGetErrorString(_status, &err_str); \
std::cerr << "ERROR: CUDA DRV call \"" << #call << "\" in line " << __LINE__ << " of file " << __FILE__ \
<< " failed with " << err_str << std::endl; \
TORCH_CHECK( \
false, \
c10::str( \
"ERROR: CUDA DRV call \"", \
#call, \
"\" in line ", \
__LINE__, \
" of file ", \
__FILE__, \
" failed with ", \
err_str)); \
} \
} while (0)
sgl-kernel/csrc/spatial/greenctx_stream.cu (new file, mode 100644)
#include <torch/all.h>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include "cuda_utils.h"
#include "greenctx_stream.h"
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
  CUgreenCtx gctx[3];
  CUdevResourceDesc desc[3];
  CUdevResource input;
  CUdevResource resources[4];
  CUstream streamA;
  CUstream streamB;
  unsigned int nbGroups = 1;

  if (smA <= 0 || smB <= 0) {
    TORCH_CHECK(false, "SM counts must be positive");
  }

  // Initialize device
  CUDA_RT(cudaInitDevice(device, 0, 0));

  // Query the SM resource of the whole device
  CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));

  // Carve out smA + smB SMs of the device for our green contexts
  unsigned int minCount = (unsigned int)(smA + smB);
  unsigned int minCountA = (unsigned int)(smA);
  TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration");

  // Split resources: first a combined partition of smA + smB SMs (resources[2]),
  // then split that partition into smA SMs (resources[0]) and the remainder (resources[1])
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount));
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
  CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM));
  CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA));
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));
  CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1));
  CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM));

  // One non-blocking stream per green context
  CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0));
  CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0));

  int smCountA = resources[0].sm.smCount;
  int smCountB = resources[1].sm.smCount;

  // The combined parent context is no longer needed once the two children exist
  CUDA_DRV(cuGreenCtxDestroy(gctx[2]));

  std::vector<int64_t> vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB};
  return vec;
}
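The function performs a two-stage split: it first carves smA + smB SMs out of the device's full SM pool into a temporary green context (gctx[2]), then splits that partition into smA SMs for stream A, with the remainder backing stream B, before destroying the temporary context. The returned vector packs the two driver stream handles together with the SM counts actually granted, which may differ from the request because the driver rounds splits to its allocation granularity. A hedged sketch of calling the raw registered op directly to inspect the granted counts (assuming importing sgl_kernel registers the op, as the new test does via the package import; the requested counts are illustrative):

import torch
import sgl_kernel  # importing the package registers torch.ops.sgl_kernel

# The op returns [streamA, streamB, smCountA, smCountB] as int64 values.
ptr_a, ptr_b, sm_a, sm_b = torch.ops.sgl_kernel.create_greenctx_stream_by_value(60, 48, 0)
print(f"granted {sm_a} + {sm_b} SMs")  # granted counts may differ from 60/48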
sgl-kernel/csrc/spatial/greenctx_stream.h (new file, mode 100644)
#include <cstdint>  // for int64_t
#include <vector>

std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
sgl-kernel/include/sgl_kernel_ops.h
@@ -661,3 +661,8 @@ void qserve_w4a8_per_group_gemm(
     const torch::Tensor& _wscales,
     const torch::Tensor& _ascales,
     torch::Tensor& _out_feats);
+
+/*
+ * From csrc/spatial
+ */
+std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device);
sgl-kernel/python/sgl_kernel/__init__.py
@@ -81,6 +81,7 @@ from sgl_kernel.sampling import (
     top_p_renorm_prob,
     top_p_sampling_from_probs,
 )
+from sgl_kernel.spatial import create_greenctx_stream_by_value, get_sm_available
 from sgl_kernel.speculative import (
     build_tree_kernel_efficient,
     segment_packbits,
sgl-kernel/python/sgl_kernel/spatial.py (new file, mode 100644)
import torch
from torch.cuda.streams import ExternalStream


def create_greenctx_stream_by_value(
    SM_a: int, SM_b: int, device_id: int = None
) -> tuple[ExternalStream, ExternalStream]:
    """
    Create two CUDA streams backed by green contexts.

    Args:
        SM_a (int): The number of SMs for stream A.
        SM_b (int): The number of SMs for stream B.
        device_id (int): The device id; defaults to the current device.

    Returns:
        tuple[ExternalStream, ExternalStream]: The two streams.
    """
    if device_id is None:
        device_id = torch.cuda.current_device()
    res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)
    stream_a = ExternalStream(stream_ptr=res[0], device=torch.device(f"cuda:{device_id}"))
    stream_b = ExternalStream(stream_ptr=res[1], device=torch.device(f"cuda:{device_id}"))
    return stream_a, stream_b


def get_sm_available(device_id: int = None) -> int:
    """
    Get the number of SMs available on the device.

    Args:
        device_id (int): The device id; defaults to the current device.

    Returns:
        int: The number of SMs available.
    """
    if device_id is None:
        device_id = torch.cuda.current_device()
    device_props = torch.cuda.get_device_properties(device_id)
    # Get the number of Streaming Multiprocessors (SMs)
    sm_count = device_props.multi_processor_count
    return sm_count
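Because ExternalStream subclasses torch.cuda.Stream, the returned streams compose with the usual stream and event APIs. A sketch of an uneven split where stream B waits on stream A's output before starting its own work (partition sizes are illustrative):

import torch
from sgl_kernel import create_greenctx_stream_by_value, get_sm_available

sm_total = get_sm_available()
stream_a, stream_b = create_greenctx_stream_by_value(3 * sm_total // 4, sm_total // 4)

x = torch.randn(4096, 4096, device="cuda")
done = torch.cuda.Event()
with torch.cuda.stream(stream_a):
    y = x @ x            # heavy work on the larger SM partition
    done.record()        # mark completion on stream A
with torch.cuda.stream(stream_b):
    done.wait()          # make stream B wait for stream A's result
    z = y.relu()         # lighter work on the smaller partition
torch.cuda.synchronize()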
sgl-kernel/tests/spatial/test_greenctx_stream.py (new file, mode 100644)
import pytest
import torch
import torch.nn.functional as F

from sgl_kernel import create_greenctx_stream_by_value, get_sm_available


def test_green_ctx():
    A = torch.randn(5120, 5120).cuda()
    B = torch.randn(5120, 5120).cuda()
    C = torch.matmul(A, B)
    sm_counts = get_sm_available(0)
    # Split the device's SMs roughly in half between the two streams
    stream_group = create_greenctx_stream_by_value(sm_counts // 2, sm_counts // 2, 0)
    with torch.cuda.stream(stream_group[0]):
        for _ in range(100):
            result_0 = torch.matmul(A, B)
    with torch.cuda.stream(stream_group[1]):
        for _ in range(100):
            result_1 = torch.matmul(A, B)
    torch.cuda.synchronize()
    assert torch.allclose(result_0, C)
    assert torch.allclose(result_1, C)


if __name__ == "__main__":
    pytest.main([__file__])