Commit 9eff7ac1 authored by xuxz's avatar xuxz Committed by zhangzbb
Browse files

[PD][BugFix]修复PD中的spec decoding的kv传输问题

parent b281794e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request
count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"]
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"]
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context()
router_socket = context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
poller.register(router_socket, zmq.POLLIN)
_listener_thread = threading.Thread(
target=_listen_for_register, args=[poller, router_socket], daemon=True
)
_listener_thread.start()
return _listener_thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
if True:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global count
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
global decode_instances
global decode_cv
with decode_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]"
)
count += 1
request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}"
)
# finish prefill
async for _ in forward_request(
f"http://{prefill_addr}/v1/completions", prefill_request, request_id
):
continue
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions", original_request_data, request_id
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
app.run(host="0.0.0.0", port=10001)
t.join()
...@@ -743,7 +743,7 @@ class DuSwiftEngine: ...@@ -743,7 +743,7 @@ class DuSwiftEngine:
"pd_pair_id": remote_address.pd_pair_id, "pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank "comm_rank": rank
} }
logger.info(f"""_send_sync_new:{data}""") #logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
response = sock.recv() response = sock.recv()
......
...@@ -3597,6 +3597,7 @@ class GPUModelRunner( ...@@ -3597,6 +3597,7 @@ class GPUModelRunner(
# Run the model. # Run the model.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
clear_kv_metadata = self.speculative_config is None
with ( with (
set_forward_context( set_forward_context(
attn_metadata, attn_metadata,
...@@ -3610,7 +3611,9 @@ class GPUModelRunner( ...@@ -3610,7 +3611,9 @@ class GPUModelRunner(
skip_compiled=has_encoder_input, skip_compiled=has_encoder_input,
), ),
record_function_or_nullcontext("gpu_model_runner: forward"), record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, self.maybe_get_kv_connector_output(
scheduler_output, clear_metadata=clear_kv_metadata
) as kv_connector_output,
): ):
model_output = self._model_forward( model_output = self._model_forward(
input_ids=input_ids, input_ids=input_ids,
...@@ -3826,6 +3829,9 @@ class GPUModelRunner( ...@@ -3826,6 +3829,9 @@ class GPUModelRunner(
# ngram and other speculative decoding methods use the sampled # ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping. # tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids) propose_draft_token_ids(valid_sampled_token_ids)
if self.speculative_config is not None:
self.clear_kv_connector_metadata()
with record_function_or_nullcontext("gpu_model_runner: eplb"): with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step() self.eplb_step()
......
...@@ -67,9 +67,12 @@ class KVConnectorModelRunnerMixin: ...@@ -67,9 +67,12 @@ class KVConnectorModelRunnerMixin:
@staticmethod @staticmethod
def maybe_get_kv_connector_output( def maybe_get_kv_connector_output(
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
clear_metadata: bool = True,
) -> AbstractContextManager[KVConnectorOutput | None]: ) -> AbstractContextManager[KVConnectorOutput | None]:
return ( return (
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output, clear_metadata=clear_metadata
)
if has_kv_transfer_group() if has_kv_transfer_group()
else nullcontext() else nullcontext()
) )
...@@ -79,7 +82,9 @@ class KVConnectorModelRunnerMixin: ...@@ -79,7 +82,9 @@ class KVConnectorModelRunnerMixin:
@staticmethod @staticmethod
@contextmanager @contextmanager
def _get_kv_connector_output( def _get_kv_connector_output(
scheduler_output: "SchedulerOutput", wait_for_save: bool = True scheduler_output: "SchedulerOutput",
wait_for_save: bool = True,
clear_metadata: bool = True,
) -> Generator[KVConnectorOutput, None, None]: ) -> Generator[KVConnectorOutput, None, None]:
output = KVConnectorOutput() output = KVConnectorOutput()
...@@ -107,7 +112,15 @@ class KVConnectorModelRunnerMixin: ...@@ -107,7 +112,15 @@ class KVConnectorModelRunnerMixin:
output.kv_connector_stats = kv_connector.get_kv_connector_stats() output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events() output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
if clear_metadata:
kv_connector.clear_connector_metadata()
@staticmethod
def clear_kv_connector_metadata() -> None:
"""Clear the KV connector metadata. Call after draft model runs."""
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
kv_connector.clear_connector_metadata() kv_connector.clear_connector_metadata()
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment