"router/vscode:/vscode.git/clone" did not exist on "21267f3ca3f121302b86c1702cc2da6091164c55"
Unverified Commit 11e27d09 authored by IAN, committed by GitHub

[PD]: Support Multi Prefill in one node (#5704)


Co-authored-by: shuaills <shishuaiuoe@gmail.com>
parent 50eda839
@@ -137,7 +137,7 @@ class DecodePreallocQueue:
             kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
             kv_receiver = kv_receiver_class(
                 mgr=self.kv_manager,
-                bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
+                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                 bootstrap_room=req.bootstrap_room,
             )
             self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
...
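For context, a minimal standalone sketch of what this one-line change means on the decode side: the receiver's bootstrap address is now built from the port carried by each request rather than a single queue-wide port, so two requests whose prefill workers share a host but listen on different bootstrap ports resolve to different addresses. The request objects below are stand-ins (`SimpleNamespace`), not the real `Req`, and the hosts/ports are illustrative.

```python
from types import SimpleNamespace

# Stand-ins for two requests whose prefill workers live on the same host
# but expose different bootstrap ports (values are illustrative).
reqs = [
    SimpleNamespace(bootstrap_host="10.0.0.1", bootstrap_port=8998, bootstrap_room=1),
    SimpleNamespace(bootstrap_host="10.0.0.1", bootstrap_port=8999, bootstrap_room=2),
]

for req in reqs:
    # Mirrors the new per-request address formation in the hunk above.
    bootstrap_addr = f"{req.bootstrap_host}:{req.bootstrap_port}"
    print(bootstrap_addr)  # 10.0.0.1:8998, then 10.0.0.1:8999
```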
@@ -6,6 +6,7 @@ import asyncio
 import random
 import urllib
 from itertools import chain
+from typing import List

 import aiohttp
 import orjson
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse


+class PrefillConfig:
+    def __init__(self, url: str, bootstrap_port: int):
+        self.url = url
+        self.bootstrap_port = bootstrap_port
+
+
 class MiniLoadBalancer:
-    def __init__(self, prefill_servers, decode_servers):
-        self.prefill_servers = prefill_servers
+    def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
+        self.prefill_configs = prefill_configs
+        self.prefill_servers = [p.url for p in prefill_configs]
         self.decode_servers = decode_servers

     def select_pair(self):
-        return random.choice(self.prefill_servers), random.choice(self.decode_servers)
+        prefill_config = random.choice(self.prefill_configs)
+        decode_server = random.choice(self.decode_servers)
+        return prefill_config.url, prefill_config.bootstrap_port, decode_server

     async def generate(
         self, modified_request, prefill_server, decode_server, endpoint
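As a rough illustration of the new selection logic (a standalone sketch, not an import of the real module), the snippet below re-declares a `PrefillConfig`-like class and mirrors `select_pair`: each choice now yields a (prefill URL, bootstrap port, decode URL) triple instead of a pair. The URLs and ports are made up.

```python
import random
from typing import List


class PrefillConfig:  # mirrors the class added in this diff
    def __init__(self, url: str, bootstrap_port: int):
        self.url = url
        self.bootstrap_port = bootstrap_port


def select_pair(prefill_configs: List[PrefillConfig], decode_servers: List[str]):
    # Same shape as MiniLoadBalancer.select_pair after this change:
    # one random prefill (URL plus its bootstrap port) and one random decode server.
    prefill_config = random.choice(prefill_configs)
    decode_server = random.choice(decode_servers)
    return prefill_config.url, prefill_config.bootstrap_port, decode_server


# Two prefill workers on one node, each with its own bootstrap port (illustrative values).
prefills = [
    PrefillConfig("http://127.0.0.1:30000", 8998),
    PrefillConfig("http://127.0.0.1:30001", 8999),
]
print(select_pair(prefills, ["http://127.0.0.1:30002"]))
```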
@@ -160,7 +170,7 @@ async def get_model_info():

 @app.post("/generate")
 async def handle_generate_request(request_data: dict):
-    prefill_server, decode_server = load_balancer.select_pair()
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
         modified_request.update(
             {
                 "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_port": [bootstrap_port] * batch_size,
                 "bootstrap_room": [
                     _generate_bootstrap_room() for _ in range(batch_size)
                 ],
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
         modified_request.update(
             {
                 "bootstrap_host": hostname,
+                "bootstrap_port": bootstrap_port,
                 "bootstrap_room": _generate_bootstrap_room(),
             }
         )
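A small standalone sketch of what the batched branch above produces. It assumes `_generate_bootstrap_room()` draws a random 63-bit integer, matching the `random.randint(0, 2**63 - 1)` used by the completion handler below; the function's actual body is not shown in this diff, and the hostname, port, batch size, and payload keys here are illustrative.

```python
import random


def _generate_bootstrap_room() -> int:
    # Assumption: a random 63-bit room id, matching the completion handler's
    # random.randint(0, 2**63 - 1); the real helper's body is not in this diff.
    return random.randint(0, 2**63 - 1)


hostname, bootstrap_port, batch_size = "10.0.0.1", 8998, 2  # illustrative
modified_request = {"text": ["hello", "world"]}  # illustrative payload
modified_request.update(
    {
        "bootstrap_host": [hostname] * batch_size,
        "bootstrap_port": [bootstrap_port] * batch_size,
        "bootstrap_room": [_generate_bootstrap_room() for _ in range(batch_size)],
    }
)
print(modified_request["bootstrap_port"])  # [8998, 8998]
```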
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):

 @app.post("/v1/chat/completions")
 async def handle_completion_request(request_data: dict):
-    prefill_server, decode_server = load_balancer.select_pair()
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()

     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
     modified_request.update(
         {
             "bootstrap_host": hostname,
+            "bootstrap_port": bootstrap_port,
             "bootstrap_room": random.randint(0, 2**63 - 1),
         }
     )
@@ -255,9 +268,9 @@ async def get_models():
         raise HTTPException(status_code=500, detail=str(e))


-def run(prefill_addrs, decode_addrs, host, port):
+def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
-    load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
+    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
     uvicorn.run(app, host=host, port=port)


@@ -268,6 +281,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prefill", required=True, help="Comma-separated URLs for prefill servers"
     )
+    parser.add_argument(
+        "--prefill-bootstrap-ports",
+        help="Comma-separated bootstrap ports for prefill servers",
+        default="8998",
+    )
     parser.add_argument(
         "--decode", required=True, help="Comma-separated URLs for decode servers"
     )
@@ -278,4 +296,23 @@ if __name__ == "__main__":
         "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
     )
     args = parser.parse_args()
-    run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
+
+    prefill_urls = args.prefill.split(",")
+    bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
+
+    if len(bootstrap_ports) == 1:
+        bootstrap_ports = bootstrap_ports * len(prefill_urls)
+    else:
+        if len(bootstrap_ports) != len(prefill_urls):
+            raise ValueError(
+                "Number of prefill URLs must match number of bootstrap ports"
+            )
+            exit(1)
+
+    prefill_configs = []
+    for url, port in zip(prefill_urls, bootstrap_ports):
+        prefill_configs.append(PrefillConfig(url, port))
+
+    decode_addrs = args.decode.split(",")
+    run(prefill_configs, decode_addrs, args.host, args.port)
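With the new flag, a single `--prefill-bootstrap-ports` value (the default `8998`) is reused for every `--prefill` URL, while a comma-separated list must line up one-to-one with the prefill URLs. The standalone sketch below demonstrates the broadcast behavior using hard-coded, illustrative URL and port strings in place of the parsed arguments.

```python
# Standalone sketch of the port handling above: a single bootstrap port is
# broadcast to every prefill URL; otherwise the counts must match.
prefill_urls = "http://127.0.0.1:30000,http://127.0.0.1:30001".split(",")  # illustrative
bootstrap_ports = [int(p) for p in "8998".split(",")]  # e.g. --prefill-bootstrap-ports 8998

if len(bootstrap_ports) == 1:
    bootstrap_ports = bootstrap_ports * len(prefill_urls)
elif len(bootstrap_ports) != len(prefill_urls):
    raise ValueError("Number of prefill URLs must match number of bootstrap ports")

print(list(zip(prefill_urls, bootstrap_ports)))
# [('http://127.0.0.1:30000', 8998), ('http://127.0.0.1:30001', 8998)]
```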
@@ -97,6 +97,7 @@ class GenerateReqInput:

     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[int], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None

     def normalize_batch_and_arguments(self):
@@ -400,6 +401,9 @@ class GenerateReqInput:
                 bootstrap_host=(
                     self.bootstrap_host[i] if self.bootstrap_host is not None else None
                 ),
+                bootstrap_port=(
+                    self.bootstrap_port[i] if self.bootstrap_port is not None else None
+                ),
                 bootstrap_room=(
                     self.bootstrap_room[i] if self.bootstrap_room is not None else None
                 ),
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:

     # For disaggregated inference
     bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

...
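As a rough standalone sketch of the per-index slicing above: when a batched request is split into per-item requests, the bootstrap host, port, and room lists are indexed by position, and each field stays `None` when disaggregation is not in use. The `MiniReq` class below is a heavily trimmed, hypothetical stand-in for `GenerateReqInput`, and the values are illustrative.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class MiniReq:
    # Trimmed stand-in for GenerateReqInput; only the disaggregation fields.
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[int], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    def split(self, batch_size: int) -> List["MiniReq"]:
        # Mirrors the per-index slicing added in the hunk above.
        return [
            MiniReq(
                bootstrap_host=(
                    self.bootstrap_host[i] if self.bootstrap_host is not None else None
                ),
                bootstrap_port=(
                    self.bootstrap_port[i] if self.bootstrap_port is not None else None
                ),
                bootstrap_room=(
                    self.bootstrap_room[i] if self.bootstrap_room is not None else None
                ),
            )
            for i in range(batch_size)
        ]


batch = MiniReq(["10.0.0.1"] * 2, [8998, 8999], [101, 102])  # illustrative values
print(batch.split(2)[1].bootstrap_port)  # -> 8999
```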
@@ -391,6 +391,7 @@ class Req:
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
+        bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
@@ -526,6 +527,7 @@ class Req:

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_port: Optional[int] = bootstrap_port
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None
...
@@ -791,6 +791,7 @@ class Scheduler(
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
+            bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
         )
         req.tokenizer = self.tokenizer
...
@@ -498,6 +498,7 @@ class TokenizerManager:
             token_ids_logprob,
             obj.stream,
             bootstrap_host=obj.bootstrap_host,
+            bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
             lora_path=obj.lora_path,
             input_embeds=input_embeds,
...