Unverified Commit cefa5281 authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[ROCm][P/D][MORI][BugFix] Ensure correct api is used when making requests to...


[ROCm][P/D][MORI][BugFix] Ensure correct api is used when making requests to prefill / decode nodes (#39835)
Signed-off-by: default avatarRandall Smith <Randall.Smith@amd.com>
parent 46794958
......@@ -12,7 +12,7 @@ import aiohttp
import msgpack
import regex as re
import zmq
from quart import Quart, make_response, request
from quart import Quart, Request, make_response, request
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
MoRIIOConstants,
......@@ -139,10 +139,13 @@ async def send_request_to_prefill(
return await response.json()
else:
raise RuntimeError(
"send_request_to_prefill response.status != 200response.status = ",
response.status,
error_message = (
f"send_request_to_prefill response ={response},"
f"reason={response.reason}, status={response.status},"
f"method={response.method}, url={response.url},"
f"real_url={response.real_url}"
)
raise RuntimeError(error_message)
async def start_decode_request(endpoint, req_data, request_id):
......@@ -163,9 +166,13 @@ async def stream_decode_response(session, response, request_id):
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
raise RuntimeError(
f"decode response.status != 200, status = {response.status}"
error_message = (
f"stream_decode_response response ={response},"
f"reason={response.reason}, status={response.status},"
f"method={response.method}, url={response.url},"
f"real_url={response.real_url}"
)
raise RuntimeError(error_message)
finally:
await session.close()
......@@ -175,8 +182,16 @@ def example_round_robin_dp_loader(request_number, dp_size):
@app.route("/v1/completions", methods=["POST"])
async def handle_completions_request():
return await handle_request("/completions", request)
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request():
async def handle_chat_completions_request():
return await handle_request("/chat/completions", request)
async def handle_request(api: str, request: Request):
try:
with _list_lock:
global request_nums
......@@ -230,9 +245,10 @@ async def handle_request():
)
req_data_to_prefill["kv_transfer_params"]["transfer_id"] = transfer_id
prefill_request_url = prefill_instance_endpoint["request_address"] + api
send_prefill_task = asyncio.create_task(
send_request_to_prefill(
prefill_instance_endpoint["request_address"],
prefill_request_url,
req_data_to_prefill,
request_id,
decode_instance_endpoint,
......@@ -241,7 +257,7 @@ async def handle_request():
selected_prefill_dp_rank,
)
)
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
ip, port = extract_ip_port_fast(prefill_request_url)
req_data["max_tokens"] -= 1
......@@ -276,10 +292,9 @@ async def handle_request():
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
req_data["kv_transfer_params"]["transfer_id"] = transfer_id
decode_request_url = decode_instance_endpoint["request_address"] + api
decode_request_task = asyncio.create_task(
start_decode_request(
decode_instance_endpoint["request_address"], req_data, request_id
)
start_decode_request(decode_request_url, req_data, request_id)
)
session, decode_response = await decode_request_task
......
......@@ -846,7 +846,7 @@ class MoRIIOConnectorWorker:
]
def _ping(self, zmq_context):
http_request_address = f"http://{self.request_address}/v1/completions"
http_request_address = f"http://{self.request_address}/v1"
role = "P" if self.is_producer else "D"
retry_count = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment