# SGL Kernel

[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang

[![License: Apache-2.0](https://img.shields.io/badge/License-Apache--2.0-blue.svg)](https://github.com/sgl-project/sglang/blob/main/LICENSE) [![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel)

SGL Kernel provides optimized compute primitives for the SGLang framework, enabling efficient inference for large language models and vision-language models through custom kernel operations.

## Installation

Requires torch == 2.8.0

```bash
# Latest version
pip3 install sgl-kernel --upgrade
```

## Building from Source

Requires
- CMake ≥3.31
- Python ≥3.10
- scikit-build-core
- ninja (optional)

### Use Makefile to build sgl-kernel

```bash
make build
```

## Contribution

### Steps to add a new kernel:

1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc)
2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h)
3. Create torch extension in [csrc/common_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/common_extension.cc)
4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include new CUDA source
5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)
6. Add test and benchmark

### Development Tips

1. When creating torch extensions, add the function definition with `m.def`, and device binding with `m.impl`:

   - How to write schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)

   ```cpp
   // We need def with schema here for torch.compile
   m.def(
       "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, "
       "int cublas_handle) -> ()");
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
   ```

### Adapting C++ Native Types for Torch Compatibility

Third-party C++ libraries often use `int` and `float`, but PyTorch bindings require `int64_t` and `double` due to Python's type mapping.
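
To make the mismatch concrete, here is a minimal standalone sketch of a hand-written wrapper that accepts the binding-friendly types and narrows them with range checks. The names `scale_by_count`, `checked_narrow_to_int`, and `scale_by_count_wrapper` are hypothetical and not part of sgl-kernel, and the sketch throws standard exceptions where real binding code would more likely use `TORCH_CHECK`, so that it compiles without PyTorch.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical stand-in for a third-party function that uses native int/float.
float scale_by_count(int count, float factor) {
  return static_cast<float>(count) * factor;
}

// Narrow int64_t to int, rejecting values that do not fit instead of truncating.
int checked_narrow_to_int(std::int64_t value) {
  if (value > std::numeric_limits<int>::max()) {
    throw std::out_of_range("value too large for int");
  }
  if (value < std::numeric_limits<int>::min()) {
    throw std::out_of_range("value too small for int");
  }
  return static_cast<int>(value);
}

// Hand-written wrapper exposing the (int64_t, double) signature that bindings expect.
double scale_by_count_wrapper(std::int64_t count, double factor) {
  return scale_by_count(checked_narrow_to_int(count), static_cast<float>(factor));
}
```

With these definitions, `scale_by_count_wrapper(3, 0.5)` returns `1.5`, while a count such as `std::int64_t{1} << 40` throws `std::out_of_range` rather than being silently truncated. Writing one such wrapper per function quickly becomes repetitive, which is the motivation for a generic adapter.
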

Use `make_pytorch_shim` from `sgl_kernel_torch_shim.h` to handle conversions automatically:

```cpp
// Add type conversion for int -> int64_t
template <>
struct pytorch_library_compatible_type<int> {
  using type = int64_t;
  static int convert_from_type(int64_t arg) {
    TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "value too large");
    TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "value too small");
    return arg;
  }
};
```

```cpp
// Wrap your function
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
```

### Testing & Benchmarking

1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests). If you need to skip a test, use `@pytest.mark.skipif`:

   ```python
   @pytest.mark.skipif(
       skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
   )
   ```

2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)

   **We recommend using `triton.testing.do_bench_cudagraph` for kernel benchmarking**:

   Compared to `triton.testing.do_bench`, `do_bench_cudagraph` provides:
   - Reduced CPU overhead impact for more accurate kernel performance measurements
   - Incorporation of PDL (Programmatic Dependent Launch) effects into individual kernel results
   - More realistic performance data on PDL-supported architectures (SM >= 90)

3. Run test suite

## FAQ

- Q: Segmentation fault with CUDA 12.6
- A: Update ptxas to 12.8, reference: [segment fault error](https://github.com/Dao-AILab/flash-attention/issues/1453)