Unverified Commit 11e27d09 authored by IAN's avatar IAN Committed by GitHub
Browse files

[PD]: Support Multi Prefill in one node (#5704)


Co-authored-by: shuaills <shishuaiuoe@gmail.com>
parent 50eda839
......@@ -137,7 +137,7 @@ class DecodePreallocQueue:
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
......
......@@ -6,6 +6,7 @@ import asyncio
import random
import urllib
from itertools import chain
from typing import List
import aiohttp
import orjson
......@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
class PrefillConfig:
    """One prefill server: its URL paired with its bootstrap port.

    Both constructor arguments are stored unchanged as the ``url`` and
    ``bootstrap_port`` attributes.
    """

    def __init__(self, url: str, bootstrap_port: int):
        self.url = url
        self.bootstrap_port = bootstrap_port

    def __repr__(self) -> str:
        # Readable form for logs and debugging sessions.
        return (
            f"{type(self).__name__}(url={self.url!r}, "
            f"bootstrap_port={self.bootstrap_port!r})"
        )
class MiniLoadBalancer:
def __init__(self, prefill_servers, decode_servers):
self.prefill_servers = prefill_servers
def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
self.prefill_configs = prefill_configs
self.prefill_servers = [p.url for p in prefill_configs]
self.decode_servers = decode_servers
def select_pair(self):
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
prefill_config = random.choice(self.prefill_configs)
decode_server = random.choice(self.decode_servers)
return prefill_config.url, prefill_config.bootstrap_port, decode_server
async def generate(
self, modified_request, prefill_server, decode_server, endpoint
......@@ -160,7 +170,7 @@ async def get_model_info():
@app.post("/generate")
async def handle_generate_request(request_data: dict):
prefill_server, decode_server = load_balancer.select_pair()
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
......@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
modified_request.update(
{
"bootstrap_host": [hostname] * batch_size,
"bootstrap_port": [bootstrap_port] * batch_size,
"bootstrap_room": [
_generate_bootstrap_room() for _ in range(batch_size)
],
......@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": _generate_bootstrap_room(),
}
)
......@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
@app.post("/v1/chat/completions")
async def handle_completion_request(request_data: dict):
prefill_server, decode_server = load_balancer.select_pair()
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
......@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
modified_request.update(
{
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
......@@ -255,9 +268,9 @@ async def get_models():
raise HTTPException(status_code=500, detail=str(e))
def run(prefill_configs, decode_addrs, host, port):
    """Create the module-level load balancer and serve the app with uvicorn.

    Args:
        prefill_configs: ``PrefillConfig`` objects, one per prefill server.
        decode_addrs: URLs of the decode servers.
        host: Host passed through to ``uvicorn.run``.
        port: Port passed through to ``uvicorn.run``.
    """
    # The request handlers read this global, so it must be set before serving.
    global load_balancer
    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
    uvicorn.run(app, host=host, port=port)
......@@ -268,6 +281,11 @@ if __name__ == "__main__":
parser.add_argument(
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
)
parser.add_argument(
"--prefill-bootstrap-ports",
help="Comma-separated bootstrap ports for prefill servers",
default="8998",
)
parser.add_argument(
"--decode", required=True, help="Comma-separated URLs for decode servers"
)
......@@ -278,4 +296,23 @@ if __name__ == "__main__":
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
)
args = parser.parse_args()

prefill_urls = args.prefill.split(",")
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]

# A single bootstrap port is shared by every prefill URL; otherwise the two
# lists must pair up one-to-one.
if len(bootstrap_ports) == 1:
    bootstrap_ports = bootstrap_ports * len(prefill_urls)
elif len(bootstrap_ports) != len(prefill_urls):
    raise ValueError(
        "Number of prefill URLs must match number of bootstrap ports"
    )

prefill_configs = [
    PrefillConfig(url, port) for url, port in zip(prefill_urls, bootstrap_ports)
]
decode_addrs = args.decode.split(",")

run(prefill_configs, decode_addrs, args.host, args.port)
......@@ -97,6 +97,7 @@ class GenerateReqInput:
# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[int], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None
def normalize_batch_and_arguments(self):
......@@ -400,6 +401,9 @@ class GenerateReqInput:
bootstrap_host=(
self.bootstrap_host[i] if self.bootstrap_host is not None else None
),
bootstrap_port=(
self.bootstrap_port[i] if self.bootstrap_port is not None else None
),
bootstrap_room=(
self.bootstrap_room[i] if self.bootstrap_room is not None else None
),
......@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_port: Optional[int] = None
bootstrap_room: Optional[int] = None
......
......@@ -391,6 +391,7 @@ class Req:
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
bootstrap_room: Optional[int] = None,
):
# Input and output info
......@@ -526,6 +527,7 @@ class Req:
# For disaggregation
self.bootstrap_host: str = bootstrap_host
self.bootstrap_port: Optional[int] = bootstrap_port
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[BaseKVSender] = None
......
......@@ -791,6 +791,7 @@ class Scheduler(
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
)
req.tokenizer = self.tokenizer
......
......@@ -498,6 +498,7 @@ class TokenizerManager:
token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
bootstrap_room=obj.bootstrap_room,
lora_path=obj.lora_path,
input_embeds=input_embeds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment