Unverified commit 8b33d8df, authored by wangxiyu191 and committed by GitHub
Browse files

[PD] Fix prefill_servers in mini_lb (#6527)

parent e235be16
...@@ -50,6 +50,13 @@ class MiniLoadBalancer: ...@@ -50,6 +50,13 @@ class MiniLoadBalancer:
self.prefill_servers = [p.url for p in prefill_configs] self.prefill_servers = [p.url for p in prefill_configs]
self.decode_servers = decode_servers self.decode_servers = decode_servers
def add_prefill_server(self, new_prefill_config: PrefillConfig):
    """Register one more prefill server with this load balancer.

    The config object is appended to ``self.prefill_configs`` and its
    ``url`` is appended to ``self.prefill_servers``, so the URL list keeps
    mirroring the config list entry for entry.

    Args:
        new_prefill_config: Config of the prefill server to add; its
            ``url`` attribute is read.
    """
    # Config first, URL second: same order as the two lists are built.
    self.prefill_configs.append(new_prefill_config)
    server_url = new_prefill_config.url
    self.prefill_servers.append(server_url)
def add_decode_server(self, new_decode_server: str):
    """Register one more decode server with this load balancer.

    Args:
        new_decode_server: Value appended to the end of
            ``self.decode_servers``.
    """
    known_servers = self.decode_servers
    known_servers.append(new_decode_server)
def select_pair(self): def select_pair(self):
# TODO: return some message instead of panic # TODO: return some message instead of panic
assert len(self.prefill_configs) > 0, "No prefill servers available" assert len(self.prefill_configs) > 0, "No prefill servers available"
...@@ -157,7 +164,7 @@ class MiniLoadBalancer: ...@@ -157,7 +164,7 @@ class MiniLoadBalancer:
app = FastAPI() app = FastAPI()
load_balancer = None load_balancer: Optional[MiniLoadBalancer] = None
@app.get("/health") @app.get("/health")
...@@ -331,14 +338,14 @@ async def get_models(): ...@@ -331,14 +338,14 @@ async def get_models():
@app.post("/register") @app.post("/register")
async def register(obj: PDRegistryRequest): async def register(obj: PDRegistryRequest):
if obj.mode == "prefill": if obj.mode == "prefill":
load_balancer.prefill_configs.append( load_balancer.add_prefill_server(
PrefillConfig(obj.registry_url, obj.bootstrap_port) PrefillConfig(obj.registry_url, obj.bootstrap_port)
) )
logger.info( logger.info(
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}" f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
) )
elif obj.mode == "decode": elif obj.mode == "decode":
load_balancer.decode_servers.append(obj.registry_url) load_balancer.add_decode_server(obj.registry_url)
logger.info(f"Registered decode server: {obj.registry_url}") logger.info(f"Registered decode server: {obj.registry_url}")
else: else:
raise HTTPException( raise HTTPException(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment