# SGL Kernel

[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang

[![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel)

## Installation
For CUDA 12.1 and above:

```bash
pip3 install sgl-kernel
```

For CUDA 11.8:

```bash
pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118
```

## Build from source

Development build:

```bash
make build
```

**Note:** `sgl-kernel` is evolving rapidly. If you encounter a compilation failure, try `make rebuild`.

### Build with [ccache](https://github.com/ccache/ccache)
```bash
# or `yum install -y ccache`.
apt-get install -y ccache
# Building with ccache is enabled when ccache is installed and CCACHE_DIR is set.
export CCACHE_DIR=/path/to/your/ccache/dir
export CCACHE_BACKEND=""
export CCACHE_KEEP_LOCAL_STORAGE="TRUE"
unset CCACHE_READONLY
python -m uv build --wheel -Cbuild-dir=build --color=always .
```

### Configuring CMake Build Options
CMake options can be configured by adding `-Ccmake.define.<option>=<value>` to the `uv build` flags.
For example, to enable building the FP4 kernels, use:
```bash
python -m uv build --wheel -Cbuild-dir=build -Ccmake.define.SGL_KERNEL_ENABLE_FP4=1 --color=always .
```
See CMakeLists.txt for more options.

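If you script your builds, the `-Ccmake.define.<option>=<value>` flags can be assembled from a dictionary of CMake options. The sketch below is a hypothetical helper (`uv_build_command` is not part of sgl-kernel); it only reproduces the `uv build` invocation shown above and assumes no other flags are needed.

```python
import shlex


def uv_build_command(cmake_defines=None, build_dir="build"):
    """Hypothetical helper: build the `uv build` argv shown in this README,
    adding one -Ccmake.define.<option>=<value> flag per CMake option."""
    cmd = ["python", "-m", "uv", "build", "--wheel", f"-Cbuild-dir={build_dir}"]
    for option, value in (cmake_defines or {}).items():
        cmd.append(f"-Ccmake.define.{option}={value}")
    cmd += ["--color=always", "."]
    return cmd


if __name__ == "__main__":
    # Prints the FP4 example command from above as a shell-quoted string
    print(shlex.join(uv_build_command({"SGL_KERNEL_ENABLE_FP4": 1})))
```

Passing the returned list to `subprocess.run(cmd, check=True)` from the `sgl-kernel` directory would run the same build as the shell command.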
### Parallel Build

We highly recommend building `sgl-kernel` with Ninja, which builds in parallel automatically.

If you build `sgl-kernel` with CMake without Ninja, set `CMAKE_BUILD_PARALLEL_LEVEL` to enable a parallel build and limit nvcc to a single thread with `SGL_KERNEL_COMPILE_THREADS=1`, for example:

```bash
CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) python -m uv build --wheel -Cbuild-dir=build \
  -Ccmake.define.SGL_KERNEL_COMPILE_THREADS=1 --color=always .
```

### ⚠️ Compilation Issue with `sgl-kernel` and CUDA 12.6

When compiling `sgl-kernel` with FlashAttention on a Hopper GPU using CUDA 12.6, you may encounter a segmentation fault:

```text
kernel/build/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu -o CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu.o
Segmentation fault (core dumped)
```

⚠️ **Note**: To ensure that FlashAttention compiles correctly on the Hopper GPU architecture (sm90), it is strongly [recommended](https://github.com/Dao-AILab/flash-attention/issues/1453) to use:
- nvcc version: 12.6
- ptxas version: 12.8

**1. Check Current Versions**

Before proceeding, verify your current CUDA tool versions:
```bash
nvcc --version
ptxas --version
```
**2. Update ptxas to 12.8 (if needed)**

1. Save the following script to a file (e.g., `update_ptxas.sh`).
```bash
#!/usr/bin/env bash
# Source: https://github.com/Dao-AILab/flash-attention/blob/7ff1b621112ba8b538e2fc6a316f2a6b6f22e518/hopper/setup.py#L404
set -ex

if [ -z "$1" ]; then
    echo "Usage: $0 <CUDA_VERSION>"
    exit 1
fi

CUDA_VERSION=$1

if awk "BEGIN {exit !("$CUDA_VERSION" >= 12.6 && "$CUDA_VERSION" < 12.8)}"; then
    NVCC_ARCHIVE_VERSION="12.8.93"
    NVCC_ARCHIVE_NAME="cuda_nvcc-linux-x86_64-${NVCC_ARCHIVE_VERSION}-archive"
    NVCC_ARCHIVE_TAR="${NVCC_ARCHIVE_NAME}.tar.xz"
    NVCC_ARCHIVE_URL="https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-x86_64/${NVCC_ARCHIVE_TAR}"

    wget "$NVCC_ARCHIVE_URL"
    tar -xf "$NVCC_ARCHIVE_TAR"

    mkdir -p /usr/local/cuda/bin
    cp "${NVCC_ARCHIVE_NAME}/bin/ptxas" /usr/local/cuda/bin/

    # Clean up temporary files
    rm -f "${NVCC_ARCHIVE_TAR}"
    rm -rf "${NVCC_ARCHIVE_NAME}"
fi
```
2. Run the script with your CUDA version as the argument, using `sudo`:
```bash
sudo bash update_ptxas.sh 12.6
# Check the version
ptxas --version
```
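To run the same check from a script (for example in CI), the `release X.Y` field printed by `nvcc --version` and `ptxas --version` can be parsed and compared as integer tuples. This is a hypothetical sketch, not part of sgl-kernel; it assumes the output contains a line such as `Cuda compilation tools, release 12.8, V12.8.93`, and it treats ptxas 12.8 or newer as acceptable, which is an assumption beyond the 12.8 recommendation above.

```python
import re


def parse_cuda_release(version_output):
    """Return (major, minor) from the `release X.Y` field of nvcc/ptxas --version output."""
    match = re.search(r"release (\d+)\.(\d+)", version_output)
    if match is None:
        raise ValueError("no 'release X.Y' field in version output")
    return int(match.group(1)), int(match.group(2))


def needs_ptxas_update(cuda_version):
    """Same range as the script above: update when 12.6 <= CUDA version < 12.8."""
    return (12, 6) <= cuda_version < (12, 8)


def ptxas_is_recommended(ptxas_version):
    """Assumption: ptxas 12.8 or any newer release is acceptable."""
    return ptxas_version >= (12, 8)
```

For example, `parse_cuda_release(subprocess.run(["ptxas", "--version"], capture_output=True, text=True).stdout)` would return the installed ptxas release as a tuple. Comparing tuples keeps a version such as `12.10` distinct from `12.1`, which a plain floating-point comparison would conflate.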

# Developer Guide

## Development Environment Setup

Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer_guide/development_guide_using_docker.md#setup-docker-container).

Create and enter the development container:
```bash
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
docker exec -it sglang_zhyncs /bin/zsh
```

## Project Structure

### Dependencies

Third-party libraries:

- [CUTLASS](https://github.com/NVIDIA/cutlass)
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
- [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM)
- [FlashAttention](https://github.com/Dao-AILab/flash-attention)

### FlashAttention FYI

FA3 can fail on some shapes when there is not enough shared memory, for example with a larger `hidden_dim` or in certain special cases. Currently, FA3 is supported on sm80/sm87 and sm86/sm89.

The main difference between sm80/sm87 and sm86/sm89 is the shared memory size. See the [CUDA C++ Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x) for more information.

`sgl-kernel` can currently build FA3 for sm80/sm86/sm89/sm90a, so FA3 is available on **A100 (tested)**, A\*0, **L20 (tested)**, L40, L40S, and **3090 (tested)**.

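Code that wants to use FA3 only where `sgl-kernel` builds it could gate on the device's compute capability. The sketch below is hypothetical (these names are not sgl-kernel APIs): the `(major, minor)` tuple would come from something like `torch.cuda.get_device_capability()`, the build targets are the ones listed above, and the shared-memory figures are the per-block limits the CUDA C++ Programming Guide gives for these architectures.

```python
# Compute capabilities that sgl-kernel builds FA3 for, per the list above
# (sm90a is compute capability 9.0 with architecture-specific features).
FA3_BUILD_ARCHS = {(8, 0), (8, 6), (8, 9), (9, 0)}

# Maximum shared memory per thread block in KB (opt-in), from the CUDA guide.
# This is the difference that separates sm80/sm87 from sm86/sm89.
MAX_SMEM_PER_BLOCK_KB = {(8, 0): 163, (8, 7): 163, (8, 6): 99, (8, 9): 99}


def fa3_build_supported(capability):
    """Return True if `capability` (major, minor) is one of the FA3 build targets."""
    return tuple(capability) in FA3_BUILD_ARCHS
```

Shapes that fit within 163 KB of shared memory on sm80 may therefore still fail on sm86/sm89, where a block is limited to 99 KB.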
### Kernel Development

Steps to add a new kernel:

1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc)
2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h)
3. Create the torch extension in [csrc/common_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/common_extension.cc)
4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include the new CUDA source
5. Expose the Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)

### Development Tips

1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need `torch::Tensor`, include `<torch/all.h>` instead of `<torch/extension.h>`; including `<torch/extension.h>` causes compilation errors when building with SABI (the Python stable ABI).

2. When creating torch extensions, add the function definition with `m.def` and the device binding with `m.impl`:
   - `torch.compile` needs `m.def` with a schema, which lets it capture the custom kernel automatically. Reference: [How to add FakeTensor](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.ptttacy8y1u9)
   - How to write a schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)

   ```cpp
   // Define the op with a schema here; torch.compile needs it
   m.def(
       "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
       "cublas_handle, int cuda_stream) -> ()");
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
   ```

3. When exposing Python interfaces, avoid passing keyword arguments (kwargs) to the C++ interface kernels; pass the arguments positionally, in schema order.

    **Avoid this:**

    ```python
    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
        q_rope=query.view(query.shape[0], -1, head_size),
        k_rope=key.view(key.shape[0], -1, head_size),
        cos_sin_cache=cos_sin_cache,
        pos_ids=positions.long(),
        interleave=(not is_neox),
        cuda_stream=get_cuda_stream(),
    )
    ```

    **Use this instead:**

    ```python
    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
        query.view(query.shape[0], -1, head_size),
        key.view(key.shape[0], -1, head_size),
        query.view(query.shape[0], -1, head_size),
        key.view(key.shape[0], -1, head_size),
        cos_sin_cache,
        positions.long(),
        (not is_neox),
        get_cuda_stream(),
    )
    ```

### Integrating Third-Party Libraries with Data Type Conversion

When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.

> The torch bindings need `double` and `int64_t` because `TORCH_LIBRARY` handles the Python-to-C++ conversion: a Python `float` corresponds to a C++ `double`, and a Python `int` corresponds to a C++ `int64_t`.

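This correspondence can be read directly off an operator schema. The hypothetical sketch below (not an sgl-kernel or PyTorch API) maps each argument of a simple schema, such as the `bmm_fp8` schema shown earlier, to the C++ parameter type a kernel would typically take. It only handles plain `Type name` arguments like the ones in this README, not defaults, optionals, or lists.

```python
import re

# Typical C++ parameter type for each schema type used in this README.
# Schema `int`/`float` (Python int/float) arrive in C++ as int64_t/double.
SCHEMA_TO_CPP = {
    "Tensor": "const at::Tensor&",
    "Tensor!": "at::Tensor&",  # `!` marks a tensor the op mutates in place
    "int": "int64_t",
    "float": "double",
    "bool": "bool",
}


def cpp_param_types(schema):
    """Return [(arg_name, cpp_type), ...] for a simple `op(Type name, ...) -> ret` schema."""
    args = re.search(r"\((.*)\)\s*->", schema).group(1).strip()
    if not args:
        return []
    result = []
    for arg in args.split(","):
        schema_type, name = arg.split()
        result.append((name, SCHEMA_TO_CPP[schema_type]))
    return result
```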
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.

When you need to support new data type conversions, you can easily add conversion functions like this:

```cpp
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
  using type = int64_t;
  static int convert_from_type(int64_t arg) {
    TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
    TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
    return static_cast<int>(arg);  // safe: range checked above
  }
};
```
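To make the range check concrete, here is a plain-Python sketch of the same idea. It is hypothetical and for illustration only: unlike `make_pytorch_shim`, which selects a conversion per parameter type at compile time, this `make_shim` decides at runtime from each argument's Python type, and it assumes a 32-bit C++ `int`.

```python
# Python mirror of the C++ specialization above (hypothetical, illustration only).
INT_MIN, INT_MAX = -(2**31), 2**31 - 1  # std::numeric_limits<int> for a 32-bit int


def convert_from_int64(arg):
    """Mirror of pytorch_library_compatible_type<int>::convert_from_type."""
    if arg > INT_MAX:
        raise OverflowError("int64_t value is too large to be converted to int")
    if arg < INT_MIN:
        raise OverflowError("int64_t value is too small to be converted to int")
    return arg


def make_shim(fn):
    """Wrap fn so that every int argument is range-checked before the call."""

    def wrapper(*args):
        # bool is a subclass of int in Python, so leave it untouched
        checked = [
            convert_from_int64(a) if isinstance(a, int) and not isinstance(a, bool) else a
            for a in args
        ]
        return fn(*checked)

    return wrapper
```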

To use this with your library functions, wrap them with `make_pytorch_shim`:

```cpp
/*
 * From flash-attention
 */
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
```

### Testing & Benchmarking

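1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests). To skip a test conditionally, use `@pytest.mark.skipif`:

```python
import pytest
import torch

# Skip unless a GPU with compute capability 10.0 or above is present
skip_condition = not torch.cuda.is_available() or torch.cuda.get_device_capability() < (10, 0)

@pytest.mark.skipif(skip_condition, reason="NVFP4 requires compute capability 10.0 or above.")
def test_nvfp4_kernel(): ...  # hypothetical test name
```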

2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
3. Run the test suite, e.g. `pytest tests/`

### FAQ

- If you encounter the following error when building with ccache: `ImportError: /usr/local/lib/python3.10/dist-packages/sgl_kernel/common_ops.abi3.so: undefined symbol: _ZN3c108ListType3getERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEENS_4Type24SingletonOrSharedTypePtrIS9_EE`, change the last command of the ccache build steps to: `python3 -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation`.

### Release new version

Update the version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py).
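
Because the version lives in two files, a quick consistency check before tagging a release can catch a missed update. This is a hypothetical sketch: it assumes `pyproject.toml` contains a line of the form `version = "..."` and `version.py` a line of the form `__version__ = "..."`, and it reads the first such line in each file.

```python
import re


def read_assignment(text, key):
    """Return the double-quoted value on the first line of the form `key = "..."`."""
    match = re.search(rf'^{re.escape(key)}\s*=\s*"([^"]+)"', text, re.MULTILINE)
    if match is None:
        raise ValueError(f"{key} not found")
    return match.group(1)


def versions_match(pyproject_text, version_py_text):
    """Compare pyproject.toml's `version` with version.py's `__version__`."""
    return read_assignment(pyproject_text, "version") == read_assignment(
        version_py_text, "__version__"
    )
```

The two arguments would be the file contents, e.g. read with `pathlib.Path("pyproject.toml").read_text()`.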