"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""

import asyncio
6
7
import dataclasses
import logging
Byron Hsu's avatar
Byron Hsu committed
8
9
10
import random
import urllib
from itertools import chain
11
from typing import List, Optional
Byron Hsu's avatar
Byron Hsu committed
12
13
14
15
16
17
18

import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

19
from sglang.srt.disaggregation.utils import PDRegistryRequest
20
from sglang.srt.utils import maybe_wrap_ipv6_address
Byron Hsu's avatar
Byron Hsu committed
21

# Read streamed responses in 64 KiB chunks to stay under aiohttp's per-chunk
# size limit (prevents aiohttp's "Chunk too big" error on large SSE payloads).
AIOHTTP_STREAM_READ_CHUNK_SIZE = 1024 * 64

def setup_logger():
    """Create and configure the module-wide "pdlb" logger.

    Returns:
        The ``logging.Logger`` named "pdlb" at INFO level with one stream
        handler attached. Safe to call multiple times: a handler is only
        attached if none exists yet, so log lines are never duplicated.
    """
    logger = logging.getLogger("pdlb")
    logger.setLevel(logging.INFO)

    # Guard against attaching a second handler on repeated calls
    # (e.g. module re-import), which would double every log line.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


logger = setup_logger()


@dataclasses.dataclass
class PrefillConfig:
    """Connection info for a single prefill server.

    Attributes:
        url: Base URL of the prefill server.
        bootstrap_port: Optional KV-bootstrap port that is forwarded to
            requests so prefill and decode servers can hand off the KV cache.
    """

    url: str
    bootstrap_port: Optional[int] = None

class MiniLoadBalancer:
    """Random-pair load balancer over prefill and decode server pools."""

    def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
        self.prefill_configs = prefill_configs
        # Parallel list of URLs, kept in sync with prefill_configs for the
        # endpoints that only need addresses (health checks, /v1/models, ...).
        self.prefill_servers = [p.url for p in prefill_configs]
        self.decode_servers = decode_servers

    def add_prefill_server(self, new_prefill_config: PrefillConfig):
        self.prefill_configs.append(new_prefill_config)
        self.prefill_servers.append(new_prefill_config.url)

    def add_decode_server(self, new_decode_server: str):
        self.decode_servers.append(new_decode_server)

    def select_pair(self):
        """Pick a random (prefill_url, bootstrap_port, decode_url) triple."""
        # TODO: return some message instead of panic
        assert len(self.prefill_configs) > 0, "No prefill servers available"
        assert len(self.decode_servers) > 0, "No decode servers available"

        prefill_config = random.choice(self.prefill_configs)
        decode_server = random.choice(self.decode_servers)
        return prefill_config.url, prefill_config.bootstrap_port, decode_server

    async def generate(
        self, modified_request, prefill_server, decode_server, endpoint
    ) -> ORJSONResponse:
        """POST the request to both servers; return the decode response.

        When logprobs were requested, the prefill server's
        ``meta_info.input_token_logprobs`` are prepended to the decode
        server's before returning.
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(
                total=3600
            )  # Add timeout for request reliability
        ) as session:
            tasks = [
                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                session.post(f"{decode_server}/{endpoint}", json=modified_request),
            ]

            # Wait for both responses to complete. Prefill should end first.
            prefill_response, decode_response = await asyncio.gather(*tasks)

            # Check the value, not mere key presence, so an explicit
            # `return_logprob: False` skips the merge (matches generate_stream).
            if modified_request.get("return_logprob", False):
                prefill_json = await prefill_response.json()
                ret_json = await decode_response.json()

                # merge `meta_info.input_token_logprobs` from prefill to decode
                if "meta_info" in ret_json:
                    if "input_token_logprobs" in ret_json["meta_info"]:
                        ret_json["meta_info"]["input_token_logprobs"] = (
                            prefill_json["meta_info"]["input_token_logprobs"]
                            + ret_json["meta_info"]["input_token_logprobs"]
                        )
            else:
                ret_json = await decode_response.json()

            return ORJSONResponse(
                content=ret_json,
                status_code=decode_response.status,
            )

    async def generate_stream(
        self, modified_request, prefill_server, decode_server, endpoint="generate"
    ):
        """Stream the decode server's SSE output back to the client.

        When logprobs were requested, the prefill stream is fully drained
        first and its input token logprobs are merged into each decode chunk.
        """
        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"

        async def stream_results():
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(
                    total=3600
                )  # Add timeout for request reliability
            ) as session:
                # Create the tasks for both prefill and decode requests
                tasks = [
                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
                ]
                # Wait for both responses to complete. Since this is streaming, they return immediately.
                prefill_response, decode_response = await asyncio.gather(*tasks)

                if modified_request.get("return_logprob", False):
                    prefill_chunks = []
                    async for chunk in prefill_response.content:
                        prefill_chunks.append(chunk)

                    # First chunk is an SSE line: strip the "data:" prefix.
                    first_prefill_chunk = (
                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
                    )
                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)

                    async for chunk in decode_response.content:
                        # Note: This is inefficient
                        # merge prefill input_token_logprobs, output_token_logprobs to decode
                        decoded_chunk = chunk.decode("utf-8")
                        if (
                            decoded_chunk
                            and decoded_chunk.startswith("data:")
                            and "[DONE]" not in decoded_chunk
                        ):
                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
                            ret_json["meta_info"]["input_token_logprobs"] = (
                                first_prefill_chunk_json["meta_info"][
                                    "input_token_logprobs"
                                ]
                                + ret_json["meta_info"]["input_token_logprobs"]
                            )

                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
                        else:
                            yield chunk
                else:
                    # Bounded chunk size avoids aiohttp's "Chunk too big" error.
                    async for chunk in decode_response.content.iter_chunked(
                        AIOHTTP_STREAM_READ_CHUNK_SIZE
                    ):
                        yield chunk

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )


app = FastAPI()
# Populated by run(); route handlers assume it is set before traffic arrives.
load_balancer: Optional[MiniLoadBalancer] = None

@app.get("/health")
async def health_check():
    # Liveness probe for the balancer itself; backends are not contacted.
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate():
    """Readiness probe: hit /health_generate on every registered backend.

    Renamed from a duplicate ``health_check`` definition that shadowed the
    /health handler at module level.
    """
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )
    async with aiohttp.ClientSession() as session:
        # Fan out one probe per server and await them as they finish.
        tasks = [
            session.post(f"{server}/health_generate")
            for server in chain(prefill_servers, decode_servers)
        ]
        for response in asyncio.as_completed(tasks):
            await response
    return Response(status_code=200)


@app.post("/flush_cache")
async def flush_cache():
    """Ask every prefill and decode server to flush its cache."""
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )
    async with aiohttp.ClientSession() as session:
        # One flush request per backend, awaited in completion order.
        pending = [
            session.post(f"{server}/flush_cache")
            for server in chain(prefill_servers, decode_servers)
        ]
        for i, fut in enumerate(asyncio.as_completed(pending)):
            await fut
    return Response(status_code=200)


@app.get("/get_server_info")
async def get_server_info():
    """Collect /get_server_info from every backend.

    Returns a dict with "prefill" and "decode" info lists plus the decode
    servers' "internal_states" (with a dummy fallback entry so that
    bench_one_batch_server.py always finds the keys it expects).
    """
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )
    prefill_infos = []
    decode_infos = []
    all_internal_states = []

    async with aiohttp.ClientSession() as session:
        # (single-iterable chain() wrappers removed — plain iteration suffices)
        for server in prefill_servers:
            server_info = await session.get(f"{server}/get_server_info")
            prefill_infos.append(await server_info.json())
        for server in decode_servers:
            server_info = await session.get(f"{server}/get_server_info")
            info_json = await server_info.json()
            decode_infos.append(info_json)
            # Extract internal_states from decode servers
            if "internal_states" in info_json:
                all_internal_states.extend(info_json["internal_states"])

    # Return format expected by bench_one_batch_server.py
    if all_internal_states:
        return {
            "internal_states": all_internal_states,
            "prefill": prefill_infos,
            "decode": decode_infos,
        }
    else:
        # Fallback with dummy data if no internal states found
        return {
            "internal_states": [
                {
                    "last_gen_throughput": 0.0,
                    "avg_spec_accept_length": None,
                }
            ],
            "prefill": prefill_infos,
            "decode": decode_infos,
        }

@app.get("/get_model_info")
async def get_model_info():
    """Serve placeholder model metadata (the balancer loads no real model)."""
    dummy_info = {
        "model_path": "/path/to/dummy/model",
        "tokenizer_path": "/path/to/dummy/tokenizer",
        "is_generation": True,
        "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
    }
    return ORJSONResponse(content=dummy_info)


@app.post("/generate")
async def handle_generate_request(request_data: dict):
    """Route a /generate request through a random prefill/decode pair.

    Injects bootstrap_host/bootstrap_port/bootstrap_room into a copy of the
    request so the two servers can hand off the KV cache; list-valued fields
    are used when the request is a batch.
    """
    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

    # Parse and transform prefill_server for bootstrap data
    parsed_url = urllib.parse.urlparse(prefill_server)
    hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
    modified_request = request_data.copy()

    batch_size = _get_request_batch_size(modified_request)
    if batch_size is not None:
        # Batched request: one bootstrap triple per batch element.
        modified_request.update(
            {
                "bootstrap_host": [hostname] * batch_size,
                "bootstrap_port": [bootstrap_port] * batch_size,
                "bootstrap_room": [
                    _generate_bootstrap_room() for _ in range(batch_size)
                ],
            }
        )
    else:
        modified_request.update(
            {
                "bootstrap_host": hostname,
                "bootstrap_port": bootstrap_port,
                "bootstrap_room": _generate_bootstrap_room(),
            }
        )

    if request_data.get("stream", False):
        return await load_balancer.generate_stream(
            modified_request, prefill_server, decode_server, "generate"
        )
    else:
        return await load_balancer.generate(
            modified_request, prefill_server, decode_server, "generate"
        )


async def _forward_to_backend(request_data: dict, endpoint_name: str):
    """Forward an OpenAI-style request to a random prefill/decode pair.

    Like handle_generate_request, but always injects scalar bootstrap fields
    (OpenAI endpoints are not batched here).
    """
    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

    # Parse and transform prefill_server for bootstrap data
    parsed_url = urllib.parse.urlparse(prefill_server)
    hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
    modified_request = request_data.copy()
    modified_request.update(
        {
            "bootstrap_host": hostname,
            "bootstrap_port": bootstrap_port,
            "bootstrap_room": _generate_bootstrap_room(),
        }
    )

    if request_data.get("stream", False):
        return await load_balancer.generate_stream(
            modified_request,
            prefill_server,
            decode_server,
            endpoint=endpoint_name,
        )
    else:
        return await load_balancer.generate(
            modified_request,
            prefill_server,
            decode_server,
            endpoint=endpoint_name,
        )

@app.post("/v1/chat/completions")
async def handle_chat_completion_request(request_data: dict):
    # OpenAI-compatible chat endpoint; delegates to the common forwarder.
    return await _forward_to_backend(request_data, "v1/chat/completions")


@app.post("/v1/completions")
async def handle_completion_request(request_data: dict):
    # OpenAI-compatible completions endpoint; delegates to the common forwarder.
    return await _forward_to_backend(request_data, "v1/completions")


def _generate_bootstrap_room():
    """Return a random 63-bit id pairing a prefill request with its decode."""
    return random.randint(0, 2**63 - 1)


# We may utilize `GenerateReqInput`'s logic later
def _get_request_batch_size(request):
    """Return the batch size of a generate request, or None when unbatched."""
    text = request.get("text")
    if text is not None:
        # A bare string is a single prompt, not a batch.
        return None if isinstance(text, str) else len(text)
    input_ids = request.get("input_ids")
    if input_ids is not None:
        # A flat int list is one prompt; a list of lists is a batch.
        return None if isinstance(input_ids[0], int) else len(input_ids)
    return None


@app.get("/v1/models")
async def get_models():
    """Proxy /v1/models from the first prefill server."""
    prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
    async with aiohttp.ClientSession() as session:
        try:
            response = await session.get(f"{prefill_server}/v1/models")
            if response.status != 200:
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Prefill server error: Status {response.status}",
                )
            return ORJSONResponse(content=await response.json())
        except HTTPException:
            # Re-raise as-is: the broad handler below used to swallow this
            # and collapse the backend's status code into a generic 500.
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))


@app.post("/register")
async def register(obj: PDRegistryRequest):
    """Register a prefill or decode server with the balancer at runtime.

    Raises:
        HTTPException(400) if obj.mode is neither "prefill" nor "decode".
    """
    if obj.mode == "prefill":
        load_balancer.add_prefill_server(
            PrefillConfig(obj.registry_url, obj.bootstrap_port)
        )
        logger.info(
            f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
        )
    elif obj.mode == "decode":
        load_balancer.add_decode_server(obj.registry_url)
        logger.info(f"Registered decode server: {obj.registry_url}")
    else:
        raise HTTPException(
            status_code=400,
            detail="Invalid mode. Must be either PREFILL or DECODE.",
        )

    logger.info(
        f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
        f"#Decode servers: {len(load_balancer.decode_servers)}"
    )

    return Response(status_code=200)


def run(prefill_configs, decode_addrs, host, port):
    """Build the global MiniLoadBalancer and serve the FastAPI app.

    Args:
        prefill_configs: list of PrefillConfig for the prefill pool.
        decode_addrs: list of decode server base URLs.
        host/port: bind address for uvicorn.
    """
    global load_balancer
    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    # FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
    from sglang.srt.disaggregation.launch_lb import main

    main()