# CPU Servers

This document describes how to set up the [SGLang](https://github.com/sgl-project/sglang) environment and run LLM inference on CPU servers.
SGLang is enabled and optimized on CPUs equipped with Intel® Advanced Matrix Extensions (Intel® AMX) instructions,
i.e., 4th generation or newer Intel® Xeon® Scalable Processors.
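
You can quickly check whether your CPU exposes AMX by looking at the CPU flags reported by Linux (a minimal sketch; the flag names are as reported by the kernel):

```bash
# AMX-capable Xeon CPUs report amx_tile / amx_int8 / amx_bf16 among the CPU flags
grep -o -m1 -E 'amx_tile|amx_int8|amx_bf16' /proc/cpuinfo
```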

## Optimized Model List

A number of popular LLMs have been optimized to run efficiently on CPU,
including notable open-source models such as the Llama series, the Qwen series,
and the DeepSeek series (e.g., DeepSeek-R1 and DeepSeek-V3.1-Terminus).

| Model Name | BF16 | W8A8_INT8 | FP8 |
|:---:|:---:|:---:|:---:|
| DeepSeek-R1 |   | [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8) | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |
| DeepSeek-V3.1-Terminus |   | [IntervitensInc/DeepSeek-V3.1-Terminus-Channel-int8](https://huggingface.co/IntervitensInc/DeepSeek-V3.1-Terminus-Channel-int8) | [deepseek-ai/DeepSeek-V3.1-Terminus](https://huggingface.co/deepseek-ai/DeepSeek-V3.1-Terminus) |
| Llama-3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | [RedHatAI/Llama-3.2-3B-Instruct-quantized.w8a8](https://huggingface.co/RedHatAI/Llama-3.2-3B-Instruct-quantized.w8a8) |   |
| Llama-3.1-8B | [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) | [RedHatAI/Meta-Llama-3.1-8B-quantized.w8a8](https://huggingface.co/RedHatAI/Meta-Llama-3.1-8B-quantized.w8a8) |   |
| QwQ-32B |   | [RedHatAI/QwQ-32B-quantized.w8a8](https://huggingface.co/RedHatAI/QwQ-32B-quantized.w8a8) |   |
| DeepSeek-Distilled-Llama |   | [RedHatAI/DeepSeek-R1-Distill-Llama-70B-quantized.w8a8](https://huggingface.co/RedHatAI/DeepSeek-R1-Distill-Llama-70B-quantized.w8a8) |   |
| Qwen3-235B |   |   | [Qwen/Qwen3-235B-A22B-FP8](https://huggingface.co/Qwen/Qwen3-235B-A22B-FP8) |

**Note:** The model identifiers listed in the table above
have been verified on 6th Gen Intel® Xeon® P-core platforms.
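
If you prefer to pre-download a model rather than letting the server fetch it at startup, the Hugging Face CLI can be used (a minimal sketch; it assumes `huggingface_hub` is installed and, for gated models, that you are logged in with your access token):

```bash
# Download one of the models from the table into the local Hugging Face cache
huggingface-cli download meta-llama/Llama-3.1-8B-Instruct
```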

## Installation

### Install Using Docker

It is recommended to use Docker for setting up the SGLang environment.
A [Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker/Dockerfile.xeon) is provided to facilitate the installation.
Replace `<secret>` below with your [HuggingFace access token](https://huggingface.co/docs/hub/en/security-tokens).

```bash
# Clone the SGLang repository
git clone https://github.com/sgl-project/sglang.git
cd sglang/docker

# Build the docker image
docker build -t sglang-cpu:latest -f Dockerfile.xeon .

# Initiate a docker container
docker run \
    -it \
    --privileged \
    --ipc=host \
    --network=host \
    -v /dev/shm:/dev/shm \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    -p 30000:30000 \
    -e "HF_TOKEN=<secret>" \
    sglang-cpu:latest /bin/bash
```

### Install From Source

If you prefer to install SGLang in a bare-metal environment,
use the commands below.
Note that the environment variable `SGLANG_USE_CPU_ENGINE=1`
is required to enable the SGLang service with the CPU engine.

```bash
# Create and activate a conda environment
conda create -n sgl-cpu python=3.12 -y
conda activate sgl-cpu

# Set PyTorch CPU as primary pip install channel to avoid installing the larger CUDA-enabled version and prevent potential runtime issues.
pip config set global.index-url https://download.pytorch.org/whl/cpu
pip config set global.extra-index-url https://pypi.org/simple

# Check if some conda related environment variables have been set
env | grep -i conda
# The following environment variable settings are required
# if they have not been set properly
export CONDA_EXE=$(which conda)
export CONDA_ROOT=${CONDA_EXE}/../..
export CONDA_PREFIX=${CONDA_ROOT}/envs/sgl-cpu
export PATH=${PATH}:${CONDA_ROOT}/bin:${CONDA_ROOT}/condabin

# Clone the SGLang code
git clone https://github.com/sgl-project/sglang.git
cd sglang
git checkout <YOUR-DESIRED-VERSION>

# Use dedicated toml file
cd python
cp pyproject_cpu.toml pyproject.toml
# Install SGLang dependencies and build the main SGLang package
pip install --upgrade pip setuptools
conda install -y libsqlite==3.48.0 gperftools tbb libnuma numactl
pip install .
pip install torch==2.7.1 torchvision==0.22.1 triton==3.3.1 --force-reinstall

# Build the CPU backend kernels
cd ../sgl-kernel
cp pyproject_cpu.toml pyproject.toml
pip install .

# Other required environment variables
# It is recommended to add these to ~/.bashrc so they do not need to be set in every new terminal
export SGLANG_USE_CPU_ENGINE=1
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libiomp5.so:${CONDA_PREFIX}/lib/libtcmalloc.so:${CONDA_PREFIX}/lib/libtbbmalloc.so.2
```
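
As a quick sanity check that the installation succeeded (a minimal sketch; it assumes the steps above completed without errors):

```bash
python -c "import sglang; print(sglang.__version__)"
```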

## Launch of the Serving Engine

Example command to launch SGLang serving:

```bash
python -m sglang.launch_server   \
    --model <MODEL_ID_OR_PATH>   \
    --trust-remote-code          \
    --disable-overlap-schedule   \
    --device cpu                 \
    --host 0.0.0.0               \
    --tp 6
```

Notes:

1. For running W8A8 quantized models, please add the flag `--quantization w8a8_int8`.

2. The flag `--tp 6` specifies that tensor parallelism will be applied with 6 ranks (TP6),
    i.e., 6 TP ranks will be used during execution.

    On a CPU platform, a TP rank corresponds to a sub-NUMA cluster (SNC).
    The number of available SNCs can usually be obtained from the operating system,
    and the TP size should not exceed the total number of SNCs available on the current system.
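
    For example, the NUMA/SNC layout can be inspected from the shell (a minimal illustration; the exact output depends on the platform and the SNC mode configured in the BIOS):

    ```bash
    # With SNC enabled, each sub-NUMA cluster is exposed as a NUMA node
    numactl -H | grep available
    # List the NUMA node count and the CPU ranges per node
    lscpu | grep -i numa
    ```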

    If the specified TP size differs from the total SNC count,
    the system will automatically use the first `n` SNCs, where `n` is the TP size.
    Note that `n` cannot exceed the total SNC count; exceeding it will result in an error.

    To specify the cores to be used, we need to explicitly set the environment variable `SGLANG_CPU_OMP_THREADS_BIND`.
    For example, if we want to run the SGLang service using the first 40 cores of each SNC on a Xeon® 6980P server,
    which has 43-43-42 cores on the 3 SNCs of a socket, we should set:

    ```bash
    export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
    ```

    Please be aware that with `SGLANG_CPU_OMP_THREADS_BIND` set,
    the amount of memory available to each rank may not be determined in advance.
    You may need to set a proper `--max-total-tokens` value to avoid out-of-memory errors.

3. For optimizing decoding with torch.compile, please add the flag `--enable-torch-compile`.
    To specify the maximum batch size when using `torch.compile`, set the flag `--torch-compile-max-bs`.
    For example, `--enable-torch-compile --torch-compile-max-bs 4` means using `torch.compile`
    and setting the maximum batch size to 4. Currently the maximum applicable batch size
    for optimizing with `torch.compile` is 16.

4. A warmup step is automatically triggered when the service is started.
    The server is ready when you see the log `The server is fired up and ready to roll!`.
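
    Once it is ready, you can optionally probe the server with a simple request (a minimal sketch; it assumes the server is listening locally on the default port 30000 and exposes the `/health` endpoint):

    ```bash
    # Expect an HTTP 200 response when the server is up
    curl -i http://localhost:30000/health
    ```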

## Benchmarking with Requests

You can benchmark the performance via the `bench_serving` script.
Run the command in another terminal.

```bash
python -m sglang.bench_serving   \
    --dataset-name random        \
    --random-input-len 1024      \
    --random-output-len 1024     \
    --num-prompts 1              \
    --request-rate inf           \
    --random-range-ratio 1.0
```

Detailed explanations of the parameters can be viewed with the following command:

```bash
python -m sglang.bench_serving -h
```

Additionally, requests can be formed with the
[OpenAI Completions API](https://docs.sglang.ai/basic_usage/openai_api_completions.html)
and sent via the command line (e.g. using `curl`) or via your own script.
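
For example, a single completion request can be sent with `curl` as follows (a minimal sketch; it assumes the server is running locally on the default port 30000, and `<MODEL_ID_OR_PATH>` should match the model used at launch):

```bash
curl http://localhost:30000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "<MODEL_ID_OR_PATH>",
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0
    }'
```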

## Example: Running DeepSeek-V3.1-Terminus

An example command to launch the service for W8A8_INT8 DeepSeek-V3.1-Terminus on a Xeon® 6980P server:

```bash
python -m sglang.launch_server                                   \
    --model IntervitensInc/DeepSeek-V3.1-Terminus-Channel-int8   \
    --trust-remote-code                                          \
    --disable-overlap-schedule                                   \
    --device cpu                                                 \
    --quantization w8a8_int8                                     \
    --host 0.0.0.0                                               \
    --mem-fraction-static 0.8                                    \
    --enable-torch-compile                                       \
    --torch-compile-max-bs 4                                     \
    --tp 6
```

Similarly, an example command to launch the service for FP8 DeepSeek-V3.1-Terminus would be:

```bash
python -m sglang.launch_server                 \
    --model deepseek-ai/DeepSeek-V3.1-Terminus \
    --trust-remote-code                        \
    --disable-overlap-schedule                 \
    --device cpu                               \
    --host 0.0.0.0                             \
    --mem-fraction-static 0.8                  \
    --enable-torch-compile                     \
    --torch-compile-max-bs 4                   \
    --tp 6
```

Note: Please set `--torch-compile-max-bs` to the maximum desired batch size for your deployment,
which can be up to 16. The value `4` in the examples is illustrative.

Then you can test with the `bench_serving` command, or construct your own command or script
following [the benchmarking example](#benchmarking-with-requests).