"docs/vscode:/vscode.git/clone" did not exist on "3961e32390ad16659b561bfd8f1dbd36b874fedf"
Unverified commit 0c8594e6, authored by Liangsheng Yin and committed by GitHub
Browse files

Optional extension for green context (#9231)

parent c186feed
......@@ -274,7 +274,6 @@ set(SOURCES
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/spatial/greenctx_stream.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/memory/store.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
......@@ -417,6 +416,18 @@ if (SGL_KERNEL_ENABLE_FA3)
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
endif()
# ---------------------------------------------------------------------------
# spatial_ops: green-context support, built as its own optional extension
# module and installed next to the other sgl_kernel binaries.
# ---------------------------------------------------------------------------
set(SPATIAL_SOURCES
    "csrc/spatial/greenctx_stream.cu"
    "csrc/spatial_extension.cc"
)

# Python extension module built against the stable ABI selected by
# SKBUILD_SABI_VERSION, with the SOABI suffix on the output file name.
Python_add_library(spatial_ops
    MODULE
    USE_SABI ${SKBUILD_SABI_VERSION}
    WITH_SOABI
    ${SPATIAL_SOURCES}
)

# The shared CUDA flag set is applied to CUDA translation units only; the
# .cc registration source is compiled without it.
target_compile_options(spatial_ops
    PRIVATE
        $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>
)

# NOTE(review): `cuda` is linked by plain name - presumably the CUDA driver
# library needed for the cu* driver API calls; confirm against the sources.
target_link_libraries(spatial_ops
    PRIVATE
        ${TORCH_LIBRARIES}
        c10
        cuda
)

install(TARGETS spatial_ops
    LIBRARY DESTINATION sgl_kernel
)
# ============================ DeepGEMM (JIT) ============================= #
# Create a separate library for DeepGEMM's Python API.
# This keeps its compilation isolated from the main common_ops.
......
......@@ -433,12 +433,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
/*
* From csrc/spatial
*/
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
}
REGISTER_EXTENSION(common_ops)
......@@ -42,6 +42,7 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
// This symbol is introduced in CUDA 12.5
const static auto pfn = probe_cuGreenCtxStreamCreate();
if (!pfn) {
TORCH_WARN("cuGreenCtxStreamCreate(cuda>=12.5) is not available, using fallback");
return create_greenctx_stream_fallback(gctx);
}
......@@ -55,17 +56,12 @@ static std::vector<int64_t> create_greenctx_stream_direct_dynamic(CUgreenCtx gct
std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) {
CUDA_DRV(cuDriverGetVersion(&CUDA_DRIVER_VERSION));
if (CUDA_DRIVER_VERSION < 12040) {
TORCH_CHECK(false, "Green Contexts feature requires CUDA Toolkit 12.4 or newer.");
}
CUgreenCtx gctx[3];
CUdevResourceDesc desc[3];
CUdevResource input;
CUdevResource resources[4];
if (smA <= 0 || smB <= 0) {
TORCH_CHECK(false, "SM counts must be positive");
}
TORCH_CHECK(smA > 0 && smB > 0, "SM counts must be positive");
CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM));
......
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <torch/all.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"
// Operator registration for the spatial_ops extension.
// Contributes one operator to the `sgl_kernel` library namespace through a
// library fragment.
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/spatial
*/
// Schema: three integer arguments (smA, smB, device); returns a list of ints.
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
// Bind the C++ function as the implementation. No dispatch key is passed here.
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
}
// REGISTER_EXTENSION is not defined in this file (presumably it comes from
// sgl_kernel_ops.h and emits the Python module entry point - TODO confirm).
// The argument is the same name as the `spatial_ops` CMake target.
REGISTER_EXTENSION(spatial_ops)
......@@ -92,7 +92,20 @@ from sgl_kernel.sampling import (
top_p_renorm_prob,
top_p_sampling_from_probs,
)
from sgl_kernel.spatial import create_greenctx_stream_by_value, get_sm_available
def create_greenctx_stream_by_value(*args, **kwargs):
    """Forward the call to ``sgl_kernel.spatial.create_greenctx_stream_by_value``.

    The target is imported inside the function, so ``sgl_kernel.spatial`` is
    only looked up when this wrapper is actually called.

    Args:
        *args: Positional arguments passed through unchanged.
        **kwargs: Keyword arguments passed through unchanged.

    Returns:
        Whatever the underlying implementation returns.
    """
    from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl

    return _impl(*args, **kwargs)
def get_sm_available(*args, **kwargs):
    """Forward the call to ``sgl_kernel.spatial.get_sm_available``.

    The target is imported inside the function, so ``sgl_kernel.spatial`` is
    only looked up when this wrapper is actually called.

    Args:
        *args: Positional arguments passed through unchanged.
        **kwargs: Keyword arguments passed through unchanged.

    Returns:
        Whatever the underlying implementation returns.
    """
    from sgl_kernel.spatial import get_sm_available as _impl

    return _impl(*args, **kwargs)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
segment_packbits,
......
import torch
from torch.cuda.streams import ExternalStream
# Importing the compiled extension is what registers its ops with torch.
# The extension is optional, so a failed import is recorded here instead of
# being raised at module import time; the public functions check
# `_spatial_import_error` and raise `_IMPORT_ERROR` from it when called.
try:
    from . import spatial_ops  # triggers TORCH extension registration
except Exception as _e:
    _spatial_import_error = _e
else:
    _spatial_import_error = None

# User-facing error raised (chained to the recorded import failure) when a
# spatial function is used without the extension being loadable.
_IMPORT_ERROR = ImportError(
    "Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)
def create_greenctx_stream_by_value(
SM_a: int, SM_b: int, device_id: int = None
......@@ -14,11 +25,8 @@ def create_greenctx_stream_by_value(
Returns:
tuple[ExternalStream, ExternalStream]: The two streams.
"""
if torch.version.cuda < "12.4":
raise RuntimeError(
"Green Contexts feature requires CUDA Toolkit 12.4 or newer."
)
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
......@@ -42,6 +50,8 @@ def get_sm_available(device_id: int = None) -> int:
Returns:
int: The SMs available.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment