Unverified commit 8bf6d7f4, authored by Yineng Zhang, committed by GitHub

support cmake for sgl-kernel (#4706)


Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: yinfan98 <1106310035@qq.com>
parent 1b9175cb
@@ -47,7 +47,7 @@ jobs:
      run: |
        docker exec ci_sglang pip install --upgrade pip
        docker exec ci_sglang pip uninstall sgl-kernel -y || true
-       docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
+       docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
        docker exec ci_sglang pip install -e "python[dev_hip]"
        docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
@@ -87,7 +87,7 @@ jobs:
      run: |
        docker exec ci_sglang pip install --upgrade pip
        docker exec ci_sglang pip uninstall sgl-kernel -y || true
-       docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
+       docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
        docker exec ci_sglang pip install -e "python[dev_hip]"
        docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
...
@@ -84,6 +84,7 @@ jobs:
        pip3 uninstall sgl-kernel -y || true
        pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
        pip3 list | grep sgl-kernel
+       git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && cd DeepGEMM && python3 setup.py develop
    - name: Run test
      timeout-minutes: 30
@@ -115,6 +116,7 @@ jobs:
        pip3 uninstall sgl-kernel -y || true
        pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
        pip3 list | grep sgl-kernel
+       git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && cd DeepGEMM && python3 setup.py develop
    - name: Run test
      timeout-minutes: 30
...
@@ -29,6 +29,8 @@ RUN git clone ${SGL_REPO} \
       git checkout ${SGL_BRANCH}; \
    fi \
    && cd sgl-kernel \
+   && rm -f pyproject.toml \
+   && mv pyproject_rocm.toml pyproject.toml \
    && python setup_rocm.py install \
    && cd .. \
    && if [ "$BUILD_TYPE" = "srt" ]; then \
...
-Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7
+Subproject commit 79fd1ae90d9b8098ca70dec6071da96f3f6da7b9
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)
# We only want to download the third-party dependencies, not build them;
# FetchContent_MakeAvailable would build them as well, so keep CMP0169 at OLD
# and use FetchContent_Populate directly.
cmake_policy(SET CMP0169 OLD)
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule REQUIRED)
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}")
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
message("CUDA_VERSION ${CUDA_VERSION} >= 12.8")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
message("CUDA_VERSION ${CUDA_VERSION} >= 12.4")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.1")
message("CUDA_VERSION ${CUDA_VERSION} >= 12.1")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
message("CUDA_VERSION ${CUDA_VERSION} >= 11.8")
endif()
find_package(Torch REQUIRED)
include(FetchContent)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG 62750a2b75c802660e4894434dc55e839f322277
GIT_SHALLOW ON
)
FetchContent_Populate(repo-cutlass)
FetchContent_Declare(
repo-deepgemm
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
GIT_TAG c57699ac933a93651c34d365797c2d8b41a4765b
GIT_SHALLOW ON
)
FetchContent_Populate(repo-deepgemm)
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer
GIT_TAG 79fd1ae90d9b8098ca70dec6071da96f3f6da7b9
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
include_directories(
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/csrc
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
set(SGL_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_75,code=sm_75"
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_89,code=sm_89"
"-gencode=arch=compute_90,code=sm_90"
"-std=c++17"
"-DFLASHINFER_ENABLE_F16"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_100"
"-gencode=arch=compute_100a,code=sm_100a"
)
else()
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-use_fast_math"
)
endif()
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_90a,code=sm_90a"
)
endif()
if (SGL_KERNEL_ENABLE_BF16)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-DFLASHINFER_ENABLE_BF16"
)
endif()
if (SGL_KERNEL_ENABLE_FP8)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-DFLASHINFER_ENABLE_FP8"
"-DFLASHINFER_ENABLE_FP8_E4M3"
"-DFLASHINFER_ENABLE_FP8_E5M2"
)
endif()
string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
set(SOURCES
"csrc/allreduce/trt_reduce_internal.cu"
"csrc/allreduce/trt_reduce_kernel.cu"
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/elementwise/activation.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/cublas_grouped_gemm.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
"csrc/gemm/nvfp4_quant_entry.cu"
"csrc/gemm/nvfp4_quant_kernels.cu"
"csrc/gemm/nvfp4_scaled_mm_entry.cu"
"csrc/gemm/nvfp4_scaled_mm_kernels.cu"
"csrc/gemm/per_tensor_quant_fp8.cu"
"csrc/gemm/per_token_group_quant_8bit.cu"
"csrc/gemm/per_token_quant_fp8.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/speculative_sampling.cu"
"csrc/speculative/packbit.cu"
"csrc/torch_extension.cc"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
)
# Build against the Python stable ABI (abi3, 3.9+) so one wheel covers multiple Python versions
Python_add_library(common_ops MODULE USE_SABI 3.9 WITH_SOABI ${SOURCES})
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
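For readers who want to drive the new CMakeLists.txt by hand rather than through the wheel build, here is a minimal sketch. It assumes a CUDA toolchain and a pip-installed torch==2.5.1; the torch.utils.cmake_prefix_path lookup is a standard PyTorch convenience, not something this commit adds:

    # Configure out-of-tree; find_package(Torch) needs the pip-installed torch on the prefix path.
    cd sgl-kernel
    cmake -S . -B build \
        -DCMAKE_BUILD_TYPE=Release \
        -DCMAKE_PREFIX_PATH="$(python3 -c 'import torch; print(torch.utils.cmake_prefix_path)')"
    # Build the common_ops extension target defined above.
    cmake --build build -j

In normal use, scikit-build-core invokes this CMakeLists automatically via uv build or pip install, as wired up in the Makefile and pyproject.toml changes below.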
@@ -19,13 +19,14 @@ submodule: ## Initialize and update git submodules
    @git submodule update --init --recursive
ln: submodule ## Create compilation database
-   @rm -rf build && bear python3 setup.py build
+   @rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES
install: submodule ## Install package in development mode
    @pip install -e .
build: submodule ## Build and install wheel package
-   @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps
+   @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && uv build --wheel -Cbuild-dir=build . --verbose --color=always && pip3 install dist/*whl --force-reinstall --no-deps
clean: ## Remove build artifacts
    @rm -rf build dist *.egg-info
...
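With these targets in place, the day-to-day flow the Makefile now encodes is roughly (target names taken from the hunk above):

    make ln      # configure CMake in build/ and emit compile_commands.json for tooling
    make build   # build the wheel via uv with build-dir=build, then force-reinstall it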
@@ -15,7 +15,7 @@ docker run --rm \
  pytorch/manylinux-builder:cuda${CUDA_VERSION} \
  bash -c "
  ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
- ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy && \
+ ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv && \
  export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
  export CUDA_VERSION=${CUDA_VERSION} && \
  export SGL_KERNEL_ENABLE_BF16=1 && \
@@ -25,5 +25,5 @@ docker run --rm \
  ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
  cd /sgl-kernel && \
  ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \
- PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel
+ PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python -m uv build --wheel -Cbuild-dir=build . --color=always
  "
@@ -18,7 +18,7 @@ limitations under the License.
 #include <c10/cuda/CUDAGuard.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #define THREADS_PER_BLOCK 128
...
@@ -12,7 +12,6 @@
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 #include <torch/all.h>
-#include <torch/extension.h>
 #include <cstdio>
 #include <cstdlib>
...
@@ -16,7 +16,6 @@ limitations under the License.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
 #include <THC/THCAtomics.cuh>
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/all.h>
 #include <torch/library.h>
 #include "sgl_kernel_ops.h"
@@ -178,9 +178,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
   m.def(
-      "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
+      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
       "cuda_stream) -> ()");
-  m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
+  m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
   m.def(
       "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
...
@@ -15,8 +15,11 @@ limitations under the License.
 #pragma once
+#include <ATen/ATen.h>
+#include <ATen/Tensor.h>
 #include <Python.h>
-#include <torch/extension.h>
+#include <torch/library.h>
+#include <torch/torch.h>
 #include <vector>
@@ -253,23 +256,12 @@ void min_p_sampling_from_probs(
     double min_p_val,
     bool deterministic,
     int64_t cuda_stream);
-// top k renorm probs
-// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
-void top_k_renorm_probs(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_k_arr,
-    unsigned int top_k_val,
-    int64_t cuda_stream);
-// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
-inline void top_k_renorm_probs_wrapper(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_k_arr,
-    int64_t top_k_val,
-    int64_t cuda_stream) {
-  top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
-}
+void top_k_renorm_probs(
+    at::Tensor probs,
+    at::Tensor renorm_probs,
+    std::optional<at::Tensor> maybe_top_k_arr,
+    int64_t top_k_val,
+    int64_t cuda_stream);
 void top_p_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
...
@@ -15,13 +15,189 @@ limitations under the License.
 #pragma once
+#include <ATen/Tensor.h>
 #include <cuda_runtime.h>
+#include <torch/all.h>
+#include <sstream>
 #ifndef USE_ROCM
-#include <pytorch_extension_utils.h>
-#endif
-#include <torch/extension.h>
-#include <sstream>
// Adapt from FlashInfer
#ifdef FLASHINFER_ENABLE_F16
#define _DISPATCH_CASE_F16(c_type, ...) \
case at::ScalarType::Half: { \
using c_type = nv_half; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_F16(c_type, ...)
#endif
#ifdef FLASHINFER_ENABLE_BF16
#define _DISPATCH_CASE_BF16(c_type, ...) \
case at::ScalarType::BFloat16: { \
using c_type = nv_bfloat16; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_BF16(c_type, ...)
#endif
#ifdef FLASHINFER_ENABLE_FP8_E4M3
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
case at::ScalarType::Float8_e4m3fn: { \
using c_type = __nv_fp8_e4m3; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
#endif
#ifdef FLASHINFER_ENABLE_FP8_E5M2
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
case at::ScalarType::Float8_e5m2: { \
using c_type = __nv_fp8_e5m2; \
return __VA_ARGS__(); \
}
#else
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
#endif
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_SWITCH(var_name, cond, ...) \
[&]() -> bool { \
switch (cond) { \
__VA_ARGS__ \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
[&]() -> bool { \
switch (pack_u16(cond1, cond2)) { \
__VA_ARGS__ \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \
<< int(cond2) << ")"; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
#define _DISPATCH_CASE(case_expr, case_var, ...) \
case case_expr: { \
constexpr auto case_var = case_expr; \
return __VA_ARGS__(); \
}
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
case pack_u16(case_expr1, case_expr2): { \
constexpr auto case_var1 = case_expr1; \
constexpr auto case_var2 = case_expr2; \
return __VA_ARGS__(); \
}
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
}
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
}
#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \
TORCH_CHECK( \
num_qo_heads % num_kv_heads == 0, \
"num_qo_heads(", \
num_qo_heads, \
") must be divisible by num_kv_heads(", \
num_kv_heads, \
")")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CUDA(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
inline bool is_float8_tensor(const at::Tensor& tensor) {
return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
}
#endif
struct cuda_error : public std::runtime_error {
  /**
...
[build-system]
requires = [
-    "setuptools>=75.0",
+    "scikit-build-core>=0.10",
     "torch==2.5.1",
     "wheel",
]
-build-backend = "setuptools.build_meta"
+build-backend = "scikit_build_core.build"
[project]
name = "sgl-kernel"
@@ -30,3 +29,11 @@ exclude = [
    "dist*",
    "tests*",
]
+[tool.scikit-build]
+cmake.build-type = "Release"
+minimum-version = "build-system.requires"
+wheel.py-api = "cp39"
+wheel.license-files = []
+wheel.packages = ["python/sgl_kernel"]
[build-system]
requires = [
"setuptools>=75.0",
"scikit-build-core>=0.10",
"torch==2.5.1",
"wheel",
]
build-backend = "setuptools.build_meta"
[project]
name = "sgl-kernel"
version = "0.0.5.post3"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA"
]
dependencies = []
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.wheel]
exclude = [
"dist*",
"tests*",
]
@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_k_renorm_probs_wrapper(
+    torch.ops.sgl_kernel.top_k_renorm_probs(
         probs,
         renorm_probs,
         maybe_top_k_arr,
...
#!/usr/bin/env bash
set -ex
WHEEL_DIR="dist"
wheel_files=($WHEEL_DIR/*.whl)
for wheel in "${wheel_files[@]}"; do
new_wheel="${wheel/linux/manylinux2014}"
if [[ "$wheel" != "$new_wheel" ]]; then
echo "Renaming $wheel to $new_wheel"
mv -- "$wheel" "$new_wheel"
fi
done
echo "Wheel renaming completed."
@@ -21,9 +21,6 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 root = Path(__file__).parent.resolve()
-if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
-    sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])
 def _get_version():
     with open(root / "pyproject.toml") as f:
...