# P2P NCCL Connector

An implementation of xPyD (x Prefill instances, y Decode instances) with dynamic scaling based on point-to-point communication, partly inspired by Dynamo.

## Detailed Design

### Overall Process

Figure 1 shows the overall process of this **PD disaggregation** solution, described here through the flow of a single request:

1. The client sends an HTTP request to the Proxy/Router's `/v1/completions` interface.
2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** pair through either round-robin or random selection, generates a `request_id` (its format is described in the next section), changes `max_tokens` in the HTTP request body to **1**, and then forwards the request to the **P instance**.
3. Immediately afterward, the Proxy/Router forwards the **original HTTP request** to the **D instance**.
4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's `zmq_addr` can be resolved from the `request_id`.
5. The **D instance** has a **dedicated thread** for receiving the KV cache (to avoid blocking the main process). The received KV cache is saved into the **GPU memory buffer**, whose size is set by the `kv_buffer_size` field of the vLLM startup parameter `--kv-transfer-config`. When the GPU buffer is full, the KV cache is stored in the **local Tensor memory pool**.
6. During **Decode**, the D instance's main process retrieves the KV cache (transmitted by the P instance) from either the **GPU buffer** or the **memory pool**, thereby **skipping Prefill**.
7. After completing **Decode**, the D instance returns the result to the **Proxy/Router**, which then forwards it to the **client**.

![image1](https://github.com/user-attachments/assets/fb01bde6-755b-49f7-ad45-48a94b1e10a7)

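Steps 2 and 3 amount to sending two versions of the same completion request. The sketch below shows only that body rewrite; `split_for_pd` is a hypothetical helper (not taken from `disagg_proxy_p2p_nccl_xpyd.py`), and the HTTP forwarding and `request_id` handling are left out.

```python
from __future__ import annotations

import copy
from typing import Any


def split_for_pd(body: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build the two request bodies the proxy sends for one client request.

    The Prefill copy is capped at one output token, so the P instance only
    runs Prefill; the Decode copy is the client's original request.
    """
    prefill_body = copy.deepcopy(body)
    prefill_body["max_tokens"] = 1
    decode_body = copy.deepcopy(body)  # unchanged original request
    return prefill_body, decode_body


if __name__ == "__main__":
    request = {"model": "base_model", "prompt": "San Francisco is a",
               "max_tokens": 10, "temperature": 0}
    p_body, d_body = split_for_pd(request)
    print(p_body["max_tokens"], d_body["max_tokens"])  # 1 10
```

Deep copies keep the client's dict untouched, so the Decode request is exactly what the client sent.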
### Proxy/Router (Demo)

A simple HTTP service acts as the entry point for client requests and starts a background thread that listens for P/D instances reporting their HTTP IP and port, as well as their ZMQ IP and port. It maintains a dictionary of `http_addr -> zmq_addr`: the `http_addr` is the IP:PORT at which a vLLM instance receives HTTP requests, while the `zmq_addr` is the address used for the KV cache handshake and metadata reception.

The Proxy/Router is responsible for selecting 1P1D based on the characteristics of the client request, such as the prompt, and generating a corresponding `request_id`, for example:

```text
cmpl-___prefill_addr_10.0.1.2:21001___decode_addr_10.0.1.3:22001_93923d63113b4b338973f24d19d4bf11-0
```

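The addresses embedded in this example are the instances' `zmq_addr`s (`21001` and `22001` are the `kv_port` values of Prefill1 and Decode1 in the 1P3D commands later in this document), which is how the P instance can find its D peer in step 4 of the request flow. A minimal sketch of building and parsing such an ID, assuming addresses never contain underscores; both helper names are hypothetical, and the `-0` suffix seen in the example is not generated here (the regex ignores anything after the address fields).

```python
from __future__ import annotations

import re
import uuid

# Hypothetical helpers mirroring the request_id layout in the example.
_PATTERN = re.compile(
    r"___prefill_addr_(?P<prefill>[^_]+)___decode_addr_(?P<decode>[^_]+)_"
)


def make_request_id(prefill_zmq: str, decode_zmq: str) -> str:
    """Embed both ZMQ addresses plus a random hex suffix in the ID."""
    return (f"cmpl-___prefill_addr_{prefill_zmq}"
            f"___decode_addr_{decode_zmq}_{uuid.uuid4().hex}")


def parse_request_id(request_id: str) -> tuple[str, str]:
    """Recover (prefill_zmq_addr, decode_zmq_addr) from a request_id."""
    match = _PATTERN.search(request_id)
    if match is None:
        raise ValueError(f"not a P2P NCCL request_id: {request_id!r}")
    return match["prefill"], match["decode"]
```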
Currently, 1P1D pairs are selected round-robin, to quickly verify that xPyD works. In the future, the plan is to select P and D using a trie combined with the load status of each instance.

Each P/D instance periodically sends a heartbeat to the Proxy/Router (currently every 3 seconds) to register (i.e., report `http_addr -> zmq_addr`) and keep the connection alive. If an instance crashes and stops sending heartbeats for a certain period of time, the Proxy/Router is meant to remove the timed-out instance (this feature has not yet been developed).

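A minimal sketch of the registration, round-robin selection, and timeout eviction described above. `InstanceRegistry` and its methods are hypothetical, the 9-second timeout (about three missed 3-second heartbeats) is an assumption, and eviction is included only to illustrate the behavior the text says is not yet implemented. Passing `now` explicitly keeps the example deterministic.

```python
from __future__ import annotations

import itertools
import time


class InstanceRegistry:
    """Proxy-side view: role -> {http_addr: (zmq_addr, last_heartbeat)}."""

    def __init__(self, timeout_s: float = 9.0) -> None:
        # Assumption: ~3 missed heartbeats (sent every 3 s) mean a dead instance.
        self.timeout_s = timeout_s
        self.instances: dict[str, dict[str, tuple[str, float]]] = {"P": {}, "D": {}}
        self._counter = itertools.count()

    def heartbeat(self, role: str, http_addr: str, zmq_addr: str,
                  now: float | None = None) -> None:
        """Register a new instance or refresh an existing one."""
        seen = time.monotonic() if now is None else now
        self.instances[role][http_addr] = (zmq_addr, seen)

    def evict_stale(self, now: float | None = None) -> list[str]:
        """Remove instances whose last heartbeat is older than timeout_s."""
        current = time.monotonic() if now is None else now
        evicted: list[str] = []
        for table in self.instances.values():
            for http_addr, (_, seen) in list(table.items()):
                if current - seen > self.timeout_s:
                    del table[http_addr]
                    evicted.append(http_addr)
        return evicted

    def select_pair(self) -> tuple[tuple[str, str], tuple[str, str]]:
        """Round-robin 1P1D choice: ((p_http, p_zmq), (d_http, d_zmq))."""
        prefill = sorted(self.instances["P"].items())
        decode = sorted(self.instances["D"].items())
        if not prefill or not decode:
            raise RuntimeError("need at least one P and one D instance")
        k = next(self._counter)
        p_http, (p_zmq, _) = prefill[k % len(prefill)]
        d_http, (d_zmq, _) = decode[k % len(decode)]
        return (p_http, p_zmq), (d_http, d_zmq)
```

One counter drives both lists, which is enough for 1P3D and 3P1D; with, for example, 2P2D it would always pair P0 with D0 and P1 with D1, so a selector meant to spread load across all pairs would need a different rotation scheme.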
### KV Cache Transfer Methods

There are three methods for KV cache transfer: PUT, GET, and PUT_ASYNC. The method is selected with the `send_type` field of `kv_connector_extra_config` inside the `--kv-transfer-config` parameter. With both PUT and PUT_ASYNC, the P instance actively sends the KV cache to the D instance. The difference is that PUT is synchronous and blocks the main process, while PUT_ASYNC sends the KV cache from a dedicated thread and therefore does not block the main process. With GET, by contrast, the P instance saves the KV cache to its memory buffer after computing the prefill, and the D instance actively retrieves the computed KV cache from the P instance once it has allocated space for it.

Experimental results have shown that the performance of these methods, from highest to lowest, is as follows: PUT_ASYNC → GET → PUT.

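The difference between PUT and PUT_ASYNC is where the send runs. A minimal sketch using the standard `queue` and `threading` modules: `put` calls the send function in the caller (blocking, like PUT), while `put_async` only enqueues and a dedicated thread performs the send (like PUT_ASYNC). `AsyncSender` and `send_fn` are hypothetical; in the connector the send would be an NCCL transfer, which this sketch does not perform.

```python
from __future__ import annotations

import queue
import threading
from typing import Callable, Optional, Tuple

SendFn = Callable[[str, bytes], None]


class AsyncSender:
    """PUT vs PUT_ASYNC: send in the caller, or hand off to a sender thread."""

    def __init__(self, send_fn: SendFn) -> None:
        self._send_fn = send_fn
        self._queue: queue.Queue[Optional[Tuple[str, bytes]]] = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def put(self, request_id: str, payload: bytes) -> None:
        """PUT-style: the caller waits until the send completes."""
        self._send_fn(request_id, payload)

    def put_async(self, request_id: str, payload: bytes) -> None:
        """PUT_ASYNC-style: returns at once; the sender thread does the work."""
        self._queue.put((request_id, payload))

    def _run(self) -> None:
        while True:
            item = self._queue.get()
            try:
                if item is None:  # shutdown sentinel
                    return
                self._send_fn(*item)
            finally:
                self._queue.task_done()

    def close(self) -> None:
        """Send everything still queued (FIFO), then stop the sender thread."""
        self._queue.put(None)
        self._thread.join()
```

GET is not modeled here: in that mode the P instance only keeps the KV cache in its buffer until the D instance pulls it.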
### P2P Communication via ZMQ & NCCL

As long as the address of the counterpart is known, point-to-point KV cache transfer (using NCCL) can be performed without being constrained by rank and world size. This supports dynamic scaling (expansion and contraction) of instances under PD disaggregation: adding or removing P/D instances does not require a full system restart.

Each P/D instance only needs to create a single `P2pNcclEngine` instance. This engine maintains a ZMQ server, which runs a dedicated thread that listens on the `zmq_addr` address and receives control-flow requests from other instances. These include requests to establish an NCCL connection and requests to send KV cache metadata (such as tensor shapes and data types). The ZMQ channel does not carry the KV cache data itself; the data is transferred over NCCL.

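Sending shape and dtype ahead of the data matters because NCCL point-to-point calls (`ncclSend`/`ncclRecv`) require the receiver to post a buffer with a matching element count and datatype. A minimal sketch of such a metadata message, assuming a JSON encoding; the field names, the `tensor_id` key, the helper functions, and the example shape are hypothetical, not the engine's actual wire format.

```python
from __future__ import annotations

import json
import math

# Hypothetical control-message layout; the engine's real format may differ.
DTYPE_BYTES = {"float16": 2, "bfloat16": 2, "float32": 4}


def encode_put_metadata(request_id: str, tensor_id: str,
                        shape: list[int], dtype: str) -> bytes:
    """Sender side: describe the tensor that will follow over NCCL."""
    msg = {"cmd": "PUT", "request_id": request_id, "tensor_id": tensor_id,
           "shape": shape, "dtype": dtype}
    return json.dumps(msg).encode("utf-8")


def recv_buffer_nbytes(raw: bytes) -> int:
    """Receiver side: bytes to reserve before posting the matching NCCL recv."""
    msg = json.loads(raw.decode("utf-8"))
    return math.prod(msg["shape"]) * DTYPE_BYTES[msg["dtype"]]
```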
When a P instance and a D instance transfer KV cache for the first time, they need to establish a ZMQ connection and an NCCL group; all subsequent transfers between them reuse both. The NCCL group consists of only two ranks (world size 2), which is what allows P/D instances to be added or removed without restarting the rest of the system.

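The lazy, per-peer setup can be sketched as a cache keyed by the remote `zmq_addr`. NCCL communicators are normally bootstrapped by distributing an `ncclUniqueId` out of band; this sketch assumes the ZMQ handshake plays that role and only records that a setup happened. `PairGroup` and `PairGroupCache` are hypothetical stand-ins, and no NCCL calls are made.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class PairGroup:
    """Stand-in for a 2-rank NCCL communicator between two GPUs."""
    local_addr: str
    remote_addr: str
    world_size: int = 2


class PairGroupCache:
    """Create a pair group on first contact with a peer, then reuse it."""

    def __init__(self, local_addr: str) -> None:
        self.local_addr = local_addr
        self._groups: dict[str, PairGroup] = {}
        self.handshakes = 0  # first-contact setups performed so far

    def get(self, remote_addr: str) -> PairGroup:
        group = self._groups.get(remote_addr)
        if group is None:
            # First transfer to this peer. Assumption: the ZMQ handshake would
            # share an NCCL unique id here and both sides would initialize a
            # 2-rank communicator; the sketch only records that it happened.
            self.handshakes += 1
            group = PairGroup(self.local_addr, remote_addr)
            self._groups[remote_addr] = group
        return group

    def drop(self, remote_addr: str) -> None:
        """Forget a departed peer; every other group is left untouched."""
        self._groups.pop(remote_addr, None)
```

Because each group involves only two peers, `drop` can forget a departed instance without touching any other group, which is the property that makes scaling in and out possible without a restart.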
### NCCL Group Topology

Currently, only symmetric TP (Tensor Parallelism) is supported for KV cache transfer; asymmetric TP and PP (Pipeline Parallelism) will be supported in the future. Figure 2 illustrates a 1P2D setup in which each instance has a TP degree of 2. There are 7 NCCL groups in total: each of the three vLLM instances has one NCCL group for TP=2 (3 groups), GPU 0 of the P instance forms an NCCL group with GPU 0 of each D instance (2 groups), and GPU 1 of the P instance likewise forms an NCCL group with GPU 1 of each D instance (2 groups).

![image2](https://github.com/user-attachments/assets/837e61d6-365e-4cbf-8640-6dd7ab295b36)

Each NCCL group occupies some GPU memory as a communication buffer, the size of which is mainly determined by the `NCCL_MAX_NCHANNELS` environment variable: with `NCCL_MAX_NCHANNELS=16` an NCCL group typically occupies 100MB, and with `NCCL_MAX_NCHANNELS=8` about 52MB. Because every P GPU forms a group with the matching GPU of every D instance, this overhead grows with the number of instances, so large-scale xPyD configurations (such as DeepSeek's 96P144D) are currently not feasible with this implementation. Moving forward, we are considering RDMA for point-to-point communication and are also keeping an eye on UCCL.

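The group count and the per-GPU buffer cost follow from this topology. A small calculator, assuming symmetric TP, that every P instance eventually pairs with every D instance (as round-robin selection does), and that no TP group exists when `tp == 1`; the per-group sizes are the approximate figures quoted above, and the function names are hypothetical.

```python
from __future__ import annotations


def nccl_group_count(num_p: int, num_d: int, tp: int) -> int:
    """Total NCCL groups for xPyD with symmetric TP, as in Figure 2."""
    # One TP group per instance (assumption: none when tp == 1).
    tp_groups = (num_p + num_d) if tp > 1 else 0
    # One 2-rank group per (P instance, D instance, TP rank) combination.
    p2p_groups = num_p * num_d * tp
    return tp_groups + p2p_groups


def p2p_buffer_mb_per_prefill_gpu(num_d: int, mb_per_group: float) -> float:
    """Each P GPU joins one 2-rank group per D instance."""
    return num_d * mb_per_group


def p2p_buffer_mb_per_decode_gpu(num_p: int, mb_per_group: float) -> float:
    """Each D GPU joins one 2-rank group per P instance."""
    return num_p * mb_per_group
```

`nccl_group_count(1, 2, 2)` returns 7, matching Figure 2. Reading 96P144D as 96 P and 144 D instances, each P GPU would hold 144 pair groups, about 144 × 100MB ≈ 14.4GB with `NCCL_MAX_NCHANNELS=16` (144 × 52MB ≈ 7.5GB with 8), before any model weights or KV cache.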
### GPU Memory Buffer and Tensor Memory Pool

The size of the memory buffer involves a trade-off. For P instances, the memory buffer is not required in PUT and PUT_ASYNC modes, but it is necessary in GET mode. For D instances, a memory buffer is needed in all three modes. The buffer should not be too large, either for D instances or for P instances in GET mode. The D-instance buffer temporarily stores KV cache sent by P instances; if it is too large, it reduces the KV cache space available for normal inference on the D instance, which shrinks the inference batch size and ultimately lowers output throughput. The buffer size is configured by `kv_buffer_size`, measured in bytes, and is typically set to 5%~10% of the GPU memory size.

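The sizing rule can be written down directly. A hypothetical helper, assuming the `kv_role` values used in the launch commands (`kv_producer`/`kv_consumer`) and treating the A800's 80GB as 80 × 10^9 bytes: 10% gives `8e9` bytes, the value the Decode commands in this document use, and a Prefill instance outside GET mode gets a token value of 1 byte.

```python
from __future__ import annotations


def kv_buffer_size(gpu_mem_bytes: int, role: str, send_type: str,
                   percent: int = 10) -> int:
    """Suggested kv_buffer_size in bytes, following the trade-off described above."""
    if send_type not in ("PUT", "GET", "PUT_ASYNC"):
        raise ValueError(f"unknown send_type: {send_type}")
    if role == "kv_producer" and send_type != "GET":
        # P instances need a buffer only in GET mode; use a token value.
        return 1
    # D instances (any mode) and P instances in GET mode: 5%~10% of GPU memory.
    return gpu_mem_bytes * percent // 100
```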
If `--max-num-seqs` is set to a large value for P instances, the large batch size causes P instances to generate a large amount of KV cache at the same time. This may exceed the capacity of the D instances' memory buffer, resulting in KV cache loss. Once KV cache is lost, the D instance must recompute Prefill, which amounts to performing Prefill twice. Consequently, time-to-first-token (TTFT) increases significantly and performance degrades.

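A back-of-the-envelope check shows why a large Prefill batch can overflow the Decode buffer. Assuming the published Llama-3.1-8B shape (32 layers, 8 KV heads, head dimension 128) and an fp16 KV cache, each token needs 2 × 32 × 8 × 128 × 2 = 131,072 bytes (128 KiB), so an `8e9`-byte buffer holds the KV cache of only 59 prompts of 1,024 tokens, well below `--max-num-seqs 256`. The function names are hypothetical, and the estimate ignores fragmentation and any per-request overhead.

```python
from __future__ import annotations


def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int) -> int:
    """K and V for every layer: 2 * layers * kv_heads * head_dim * bytes."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def prompts_that_fit(buffer_bytes: int, prompt_tokens: int,
                     per_token_bytes: int) -> int:
    """Whole prompts whose KV cache fits in the buffer (no fragmentation)."""
    return buffer_bytes // (prompt_tokens * per_token_bytes)


# Assumed Llama-3.1-8B config: 32 layers, 8 KV heads (GQA), head_dim 128, fp16.
PER_TOKEN = kv_bytes_per_token(32, 8, 128, 2)  # 131072 bytes = 128 KiB
```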
To address these issues, I have designed and developed a local Tensor memory pool for storing KV cache, inspired by the buddy system used in Linux memory management. Because host memory on servers is large (typically in the TB range), there is no need for prefix caching or for block-based designs that reuse memory to save space. When the memory buffer is insufficient, KV cache can be stored directly in the Tensor memory pool, and D instances can later retrieve it from there. Reads and writes run at PCIe speed; PCIe 4.0 provides approximately 21 GB/s, which is usually faster than Prefill (otherwise, solutions like Mooncake and lmcache would not be necessary). The Tensor memory pool acts as a flood diversion area that typically sits unused except during sudden traffic surges. In the worst case, my solution performs no worse than a normal setup with a cache store.

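A toy version of the buddy technique the pool is inspired by (a classic buddy allocator, not the connector's actual memory-pool code) shows the two operations that matter: requests are rounded up to a power of two and carved out by splitting larger blocks, and freed blocks are merged back with their "buddy", the neighbor at `offset XOR size`. `BuddyAllocator` and its methods are hypothetical, and the sketch only manages offsets, whereas the real pool stores actual tensors in host memory.

```python
from __future__ import annotations


class BuddyAllocator:
    """Toy buddy allocator handing out offsets in a pool of total_size bytes."""

    def __init__(self, total_size: int, min_block: int = 1024) -> None:
        for size in (total_size, min_block):
            if size <= 0 or size & (size - 1):
                raise ValueError("sizes must be positive powers of two")
        self.total_size = total_size
        self.min_block = min_block
        # free_lists[block_size] -> offsets of free blocks of that size
        self.free_lists: dict[int, set[int]] = {total_size: {0}}
        self.allocated: dict[int, int] = {}  # offset -> block size

    def _round_up(self, nbytes: int) -> int:
        size = self.min_block
        while size < nbytes:
            size *= 2
        return size

    def allocate(self, nbytes: int) -> int:
        size = self._round_up(nbytes)
        # Find the smallest free block that is large enough.
        block = size
        while block <= self.total_size and not self.free_lists.get(block):
            block *= 2
        if block > self.total_size:
            raise MemoryError(f"no free block for {nbytes} bytes")
        offset = self.free_lists[block].pop()
        # Split in halves until the block matches; upper halves become free.
        while block > size:
            block //= 2
            self.free_lists.setdefault(block, set()).add(offset + block)
        self.allocated[offset] = size
        return offset

    def free(self, offset: int) -> None:
        size = self.allocated.pop(offset)
        # Merge with the buddy (offset XOR size) for as long as it is free.
        while size < self.total_size:
            buddy = offset ^ size
            peers = self.free_lists.get(size, set())
            if buddy not in peers:
                break
            peers.remove(buddy)
            offset = min(offset, buddy)
            size *= 2
        self.free_lists.setdefault(size, set()).add(offset)
```

For scale: assuming 128 KiB of KV cache per token (Llama-3.1-8B in fp16), a 1,024-token prompt's KV cache is 128 MiB, which takes roughly 134,217,728 / 21e9 ≈ 6.4 ms to move at the ~21 GB/s PCIe 4.0 figure quoted above.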
## Install vLLM

```shell
pip install "vllm>=0.9.2"
```

## Run xPyD

### Instructions

- The following examples are run on A800 (80GB) GPUs, using the Meta-Llama-3.1-8B-Instruct model.
- Pay attention to the `kv_buffer_size` setting (in bytes). The empirical value is 10% of the GPU memory size; the right value depends on the KV cache size. If it is too small, the GPU memory buffer that temporarily stores received KV cache will overflow, causing KV cache to be stored in the tensor memory pool, which increases latency. If it is too large, the KV cache space available for inference shrinks, leading to a smaller batch size and lower throughput.
- For Prefill instances in non-GET modes, `kv_buffer_size` can be set to a tiny value such as 1 (the Prefill commands below use `1e1`), since Prefill currently does not need to receive KV cache. In GET mode, however, a larger `kv_buffer_size` is required, because the Prefill instance must hold the KV cache until the D instance retrieves it.
- You may need to modify `kv_buffer_size` and `port` in the following commands if there is a conflict.
- `PUT_ASYNC` offers the best performance and should be preferred.
- The `--port` must match the `http_port` in `--kv-transfer-config`.
- The `disagg_proxy_p2p_nccl_xpyd.py` script uses port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances).
- The node running the proxy must have `quart` installed.
- Multiple nodes are supported; you just need to modify `proxy_ip` and `proxy_port` in `--kv-transfer-config`.
- In the following examples, it is assumed that **the proxy's IP is 10.0.1.1**.

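The per-instance commands differ only in a few values, and `--port` must equal `http_port`. A minimal sketch that generates the `--kv-transfer-config` JSON in the same layout the commands use and derives `http_port` from `--port` so the two cannot drift apart. The helper names are hypothetical; the optional `send_type` key is placed in `kv_connector_extra_config` as described in "KV Cache Transfer Methods", and the commands in this document do not set it.

```python
from __future__ import annotations

import json
import shlex


def kv_transfer_config(role: str, kv_port: int, http_port: int,
                       kv_buffer_size: str, proxy_ip: str = "10.0.1.1",
                       proxy_port: int = 30001,
                       send_type: str | None = None) -> str:
    """Compact JSON for --kv-transfer-config, in the layout the commands use."""
    if role not in ("kv_producer", "kv_consumer"):
        raise ValueError(f"unexpected kv_role: {role}")
    extra: dict[str, str] = {
        "proxy_ip": proxy_ip,
        "proxy_port": str(proxy_port),
        "http_port": str(http_port),
    }
    if send_type is not None:
        if send_type not in ("PUT", "GET", "PUT_ASYNC"):
            raise ValueError(f"unknown send_type: {send_type}")
        extra["send_type"] = send_type
    config = {
        "kv_connector": "P2pNcclConnector",
        "kv_role": role,
        "kv_buffer_size": kv_buffer_size,
        "kv_port": str(kv_port),
        "kv_connector_extra_config": extra,
    }
    return json.dumps(config, separators=(",", ":"))


def serve_flags(port: int, role: str, kv_port: int, kv_buffer_size: str) -> str:
    """--port and a shell-quoted --kv-transfer-config built from one port."""
    cfg = kv_transfer_config(role, kv_port, port, kv_buffer_size)
    return f"--port {port} --kv-transfer-config {shlex.quote(cfg)}"
```

For example, `kv_transfer_config("kv_consumer", 22001, 20002, "8e9")` reproduces the JSON string used by Decode1 in the 1P3D example.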
### Run 1P3D

#### Proxy (e.g. 10.0.1.1)

```shell
cd {your vllm directory}/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/
python3 disagg_proxy_p2p_nccl_xpyd.py &
```

#### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20001 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.9 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20001"}}' > /var/vllm.log 2>&1 &
    ```

#### Decode1 (e.g. 10.0.1.3 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20002 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.7 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20002"}}' > /var/vllm.log 2>&1 &
    ```

#### Decode2 (e.g. 10.0.1.4 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20003 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.7 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003"}}' > /var/vllm.log 2>&1 &
    ```

#### Decode3 (e.g. 10.0.1.5 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20004 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.7 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20004"}}' > /var/vllm.log 2>&1 &
    ```

### Run 3P1D

#### Proxy (e.g. 10.0.1.1)

```shell
cd {your vllm directory}/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/
python3 disagg_proxy_p2p_nccl_xpyd.py &
```

#### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20001 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.9 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20001"}}' > /var/vllm.log 2>&1 &
    ```

#### Prefill2 (e.g. 10.0.1.3 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20002 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.9 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20002"}}' > /var/vllm.log 2>&1 &
    ```

#### Prefill3 (e.g. 10.0.1.4 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20003 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.9 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003"}}' > /var/vllm.log 2>&1 &
    ```

#### Decode1 (e.g. 10.0.1.5 or 10.0.1.1)

??? console "Command"

    ```shell
    CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
        --host 0.0.0.0 \
        --port 20004 \
        --tensor-parallel-size 1 \
        --seed 1024 \
        --served-model-name base_model \
        --dtype float16 \
        --max-model-len 10000 \
        --max-num-batched-tokens 10000 \
        --max-num-seqs 256 \
        --trust-remote-code \
        --gpu-memory-utilization 0.7 \
        --kv-transfer-config \
        '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20004"}}' > /var/vllm.log 2>&1 &
    ```

## Single request

```shell
curl -X POST -s http://10.0.1.1:10001/v1/completions \
-H "Content-Type: application/json" \
-d '{
    "model": "base_model",
    "prompt": "San Francisco is a",
    "max_tokens": 10,
    "temperature": 0
}'
```

## Benchmark

??? console "Command"

    ```shell
    vllm bench serve \
        --backend vllm \
        --model base_model \
        --tokenizer meta-llama/Llama-3.1-8B-Instruct \
        --dataset-name "random" \
        --host 10.0.1.1 \
        --port 10001 \
        --random-input-len 1024 \
        --random-output-len 1024 \
        --ignore-eos \
        --burstiness 100 \
        --percentile-metrics "ttft,tpot,itl,e2el" \
        --metric-percentiles "90,95,99" \
        --seed $(date +%s) \
        --trust-remote-code \
        --request-rate 3 \
        --num-prompts 1000
    ```

## Shut down

```shell
pgrep python | xargs kill -9 && pkill -f python
```

## Test data

### **Scenario**: 1K input & 200 output tokens, E2E P99 latency ~2s

![testdata](https://github.com/user-attachments/assets/cef0953b-4567-4bf9-b940-405b92a28eb1)