Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0c8594e6
Unverified
Commit
0c8594e6
authored
Aug 15, 2025
by
Liangsheng Yin
Committed by
GitHub
Aug 15, 2025
Browse files
Optional extension for green context (#9231)
parent
c186feed
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
20 deletions
+73
-20
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+12
-1
sgl-kernel/csrc/common_extension.cc
sgl-kernel/csrc/common_extension.cc
+0
-6
sgl-kernel/csrc/spatial/greenctx_stream.cu
sgl-kernel/csrc/spatial/greenctx_stream.cu
+3
-7
sgl-kernel/csrc/spatial_extension.cc
sgl-kernel/csrc/spatial_extension.cc
+29
-0
sgl-kernel/python/sgl_kernel/__init__.py
sgl-kernel/python/sgl_kernel/__init__.py
+14
-1
sgl-kernel/python/sgl_kernel/spatial.py
sgl-kernel/python/sgl_kernel/spatial.py
+15
-5
No files found.
sgl-kernel/CMakeLists.txt
View file @
0c8594e6
...
@@ -274,7 +274,6 @@ set(SOURCES
...
@@ -274,7 +274,6 @@ set(SOURCES
"csrc/kvcacheio/transfer.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/speculative/packbit.cu"
"csrc/spatial/greenctx_stream.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/memory/store.cu"
"csrc/memory/store.cu"
"
${
repo-flashinfer_SOURCE_DIR
}
/csrc/norm.cu"
"
${
repo-flashinfer_SOURCE_DIR
}
/csrc/norm.cu"
...
@@ -417,6 +416,18 @@ if (SGL_KERNEL_ENABLE_FA3)
...
@@ -417,6 +416,18 @@ if (SGL_KERNEL_ENABLE_FA3)
target_compile_definitions
(
flash_ops PRIVATE
${
FLASH_OPS_COMPILE_DEFS
}
)
target_compile_definitions
(
flash_ops PRIVATE
${
FLASH_OPS_COMPILE_DEFS
}
)
endif
()
endif
()
# Build spatial_ops as a separate, optional extension for green contexts
set
(
SPATIAL_SOURCES
"csrc/spatial/greenctx_stream.cu"
"csrc/spatial_extension.cc"
)
Python_add_library
(
spatial_ops MODULE USE_SABI
${
SKBUILD_SABI_VERSION
}
WITH_SOABI
${
SPATIAL_SOURCES
}
)
target_compile_options
(
spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
${
SGL_KERNEL_CUDA_FLAGS
}
>
)
target_link_libraries
(
spatial_ops PRIVATE
${
TORCH_LIBRARIES
}
c10 cuda
)
install
(
TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel
)
# ============================ DeepGEMM (JIT) ============================= #
# ============================ DeepGEMM (JIT) ============================= #
# Create a separate library for DeepGEMM's Python API.
# Create a separate library for DeepGEMM's Python API.
# This keeps its compilation isolated from the main common_ops.
# This keeps its compilation isolated from the main common_ops.
...
...
sgl-kernel/csrc/common_extension.cc
View file @
0c8594e6
...
@@ -433,12 +433,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
...
@@ -433,12 +433,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"Tensor _ascales, Tensor! _out_feats) -> ()"
);
"Tensor _ascales, Tensor! _out_feats) -> ()"
);
m
.
impl
(
"qserve_w4a8_per_group_gemm"
,
torch
::
kCUDA
,
&
qserve_w4a8_per_group_gemm
);
m
.
impl
(
"qserve_w4a8_per_group_gemm"
,
torch
::
kCUDA
,
&
qserve_w4a8_per_group_gemm
);
/*
* From csrc/spatial
*/
m
.
def
(
"create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]"
);
m
.
impl
(
"create_greenctx_stream_by_value"
,
&
create_greenctx_stream_by_value
);
}
}
REGISTER_EXTENSION
(
common_ops
)
REGISTER_EXTENSION
(
common_ops
)
sgl-kernel/csrc/spatial/greenctx_stream.cu
View file @
0c8594e6
...
@@ -42,6 +42,7 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
...
@@ -42,6 +42,7 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
// This symbol is introduced in CUDA 12.5
// This symbol is introduced in CUDA 12.5
const
static
auto
pfn
=
probe_cuGreenCtxStreamCreate
();
const
static
auto
pfn
=
probe_cuGreenCtxStreamCreate
();
if
(
!
pfn
)
{
if
(
!
pfn
)
{
TORCH_WARN
(
"cuGreenCtxStreamCreate(cuda>=12.5) is not available, using fallback"
);
return
create_greenctx_stream_fallback
(
gctx
);
return
create_greenctx_stream_fallback
(
gctx
);
}
}
...
@@ -55,17 +56,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
...
@@ -55,17 +56,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
std
::
vector
<
int64_t
>
create_greenctx_stream_by_value
(
int64_t
smA
,
int64_t
smB
,
int64_t
device
)
{
std
::
vector
<
int64_t
>
create_greenctx_stream_by_value
(
int64_t
smA
,
int64_t
smB
,
int64_t
device
)
{
CUDA_DRV
(
cuDriverGetVersion
(
&
CUDA_DRIVER_VERSION
));
CUDA_DRV
(
cuDriverGetVersion
(
&
CUDA_DRIVER_VERSION
));
if
(
CUDA_DRIVER_VERSION
<
12040
)
{
TORCH_CHECK
(
false
,
"Green Contexts feature requires CUDA Toolkit 12.4 or newer."
);
}
CUgreenCtx
gctx
[
3
];
CUgreenCtx
gctx
[
3
];
CUdevResourceDesc
desc
[
3
];
CUdevResourceDesc
desc
[
3
];
CUdevResource
input
;
CUdevResource
input
;
CUdevResource
resources
[
4
];
CUdevResource
resources
[
4
];
if
(
smA
<=
0
||
smB
<=
0
)
{
TORCH_CHECK
(
false
,
"SM counts must be positive"
);
TORCH_CHECK
(
smA
>
0
&&
smB
>
0
,
"SM counts must be positive"
);
}
CUDA_DRV
(
cuDeviceGetDevResource
((
CUdevice
)
device
,
&
input
,
CU_DEV_RESOURCE_TYPE_SM
));
CUDA_DRV
(
cuDeviceGetDevResource
((
CUdevice
)
device
,
&
input
,
CU_DEV_RESOURCE_TYPE_SM
));
...
...
sgl-kernel/csrc/spatial_extension.cc
0 → 100644
View file @
0c8594e6
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <torch/all.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT
(
sgl_kernel
,
m
)
{
/*
* From csrc/spatial
*/
m
.
def
(
"create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]"
);
m
.
impl
(
"create_greenctx_stream_by_value"
,
&
create_greenctx_stream_by_value
);
}
REGISTER_EXTENSION
(
spatial_ops
)
sgl-kernel/python/sgl_kernel/__init__.py
View file @
0c8594e6
...
@@ -92,7 +92,20 @@ from sgl_kernel.sampling import (
...
@@ -92,7 +92,20 @@ from sgl_kernel.sampling import (
top_p_renorm_prob
,
top_p_renorm_prob
,
top_p_sampling_from_probs
,
top_p_sampling_from_probs
,
)
)
from
sgl_kernel.spatial
import
create_greenctx_stream_by_value
,
get_sm_available
def
create_greenctx_stream_by_value
(
*
args
,
**
kwargs
):
from
sgl_kernel.spatial
import
create_greenctx_stream_by_value
as
_impl
return
_impl
(
*
args
,
**
kwargs
)
def
get_sm_available
(
*
args
,
**
kwargs
):
from
sgl_kernel.spatial
import
get_sm_available
as
_impl
return
_impl
(
*
args
,
**
kwargs
)
from
sgl_kernel.speculative
import
(
from
sgl_kernel.speculative
import
(
build_tree_kernel_efficient
,
build_tree_kernel_efficient
,
segment_packbits
,
segment_packbits
,
...
...
sgl-kernel/python/sgl_kernel/spatial.py
View file @
0c8594e6
import
torch
import
torch
from
torch.cuda.streams
import
ExternalStream
from
torch.cuda.streams
import
ExternalStream
try
:
from
.
import
spatial_ops
# triggers TORCH extension registration
except
Exception
as
_e
:
_spatial_import_error
=
_e
else
:
_spatial_import_error
=
None
_IMPORT_ERROR
=
ImportError
(
"Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)
def
create_greenctx_stream_by_value
(
def
create_greenctx_stream_by_value
(
SM_a
:
int
,
SM_b
:
int
,
device_id
:
int
=
None
SM_a
:
int
,
SM_b
:
int
,
device_id
:
int
=
None
...
@@ -14,11 +25,8 @@ def create_greenctx_stream_by_value(
...
@@ -14,11 +25,8 @@ def create_greenctx_stream_by_value(
Returns:
Returns:
tuple[ExternalStream, ExternalStream]: The two streams.
tuple[ExternalStream, ExternalStream]: The two streams.
"""
"""
if
torch
.
version
.
cuda
<
"12.4"
:
if
_spatial_import_error
is
not
None
:
raise
RuntimeError
(
raise
_IMPORT_ERROR
from
_spatial_import_error
"Green Contexts feature requires CUDA Toolkit 12.4 or newer."
)
if
device_id
is
None
:
if
device_id
is
None
:
device_id
=
torch
.
cuda
.
current_device
()
device_id
=
torch
.
cuda
.
current_device
()
...
@@ -42,6 +50,8 @@ def get_sm_available(device_id: int = None) -> int:
...
@@ -42,6 +50,8 @@ def get_sm_available(device_id: int = None) -> int:
Returns:
Returns:
int: The SMs available.
int: The SMs available.
"""
"""
if
_spatial_import_error
is
not
None
:
raise
_IMPORT_ERROR
from
_spatial_import_error
if
device_id
is
None
:
if
device_id
is
None
:
device_id
=
torch
.
cuda
.
current_device
()
device_id
=
torch
.
cuda
.
current_device
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment