Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9eff7ac1
Commit
9eff7ac1
authored
Apr 02, 2026
by
xuxz
Committed by
zhangzbb
Apr 02, 2026
Browse files
[PD][BugFix]修复PD中的spec decoding的kv传输问题
parent
b281794e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
178 additions
and
4 deletions
+178
-4
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_mult_mac.py
...ving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_mult_mac.py
+155
-0
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_engine.py
...ributed/kv_transfer/kv_connector/v1/du/du_swift_engine.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+7
-1
vllm/v1/worker/kv_connector_model_runner_mixin.py
vllm/v1/worker/kv_connector_model_runner_mixin.py
+15
-2
No files found.
examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd_mult_mac.py
0 → 100644
View file @
9eff7ac1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
socket
import
threading
import
uuid
import
aiohttp
import
msgpack
import
zmq
from
quart
import
Quart
,
make_response
,
request
count
=
0
prefill_instances
:
dict
[
str
,
str
]
=
{}
# http_address: zmq_address
decode_instances
:
dict
[
str
,
str
]
=
{}
# http_address: zmq_address
prefill_cv
=
threading
.
Condition
()
decode_cv
=
threading
.
Condition
()
def
_listen_for_register
(
poller
,
router_socket
):
while
True
:
socks
=
dict
(
poller
.
poll
())
if
router_socket
in
socks
:
remote_address
,
message
=
router_socket
.
recv_multipart
()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data
=
msgpack
.
loads
(
message
)
if
data
[
"type"
]
==
"P"
:
global
prefill_instances
global
prefill_cv
with
prefill_cv
:
prefill_instances
[
data
[
"http_address"
]]
=
data
[
"zmq_address"
]
elif
data
[
"type"
]
==
"D"
:
global
decode_instances
global
decode_cv
with
decode_cv
:
decode_instances
[
data
[
"http_address"
]]
=
data
[
"zmq_address"
]
else
:
print
(
"Unexpected, Received message from %s, data: %s"
,
remote_address
,
data
,
)
def
start_service_discovery
(
hostname
,
port
):
if
not
hostname
:
hostname
=
socket
.
gethostname
()
if
port
==
0
:
raise
ValueError
(
"Port cannot be 0"
)
context
=
zmq
.
Context
()
router_socket
=
context
.
socket
(
zmq
.
ROUTER
)
router_socket
.
bind
(
f
"tcp://
{
hostname
}
:
{
port
}
"
)
poller
=
zmq
.
Poller
()
poller
.
register
(
router_socket
,
zmq
.
POLLIN
)
_listener_thread
=
threading
.
Thread
(
target
=
_listen_for_register
,
args
=
[
poller
,
router_socket
],
daemon
=
True
)
_listener_thread
.
start
()
return
_listener_thread
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
6
*
60
*
60
)
app
=
Quart
(
__name__
)
def
random_uuid
()
->
str
:
return
str
(
uuid
.
uuid4
().
hex
)
async
def
forward_request
(
url
,
data
,
request_id
):
async
with
aiohttp
.
ClientSession
(
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
"X-Request-Id"
:
request_id
,
}
async
with
session
.
post
(
url
=
url
,
json
=
data
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
if
True
:
async
for
chunk_bytes
in
response
.
content
.
iter_chunked
(
1024
):
yield
chunk_bytes
else
:
content
=
await
response
.
read
()
yield
content
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
async
def
handle_request
():
try
:
original_request_data
=
await
request
.
get_json
()
prefill_request
=
original_request_data
.
copy
()
# change max_tokens = 1 to let it only do prefill
prefill_request
[
"max_tokens"
]
=
1
global
count
global
prefill_instances
global
prefill_cv
with
prefill_cv
:
prefill_list
=
list
(
prefill_instances
.
items
())
prefill_addr
,
prefill_zmq_addr
=
prefill_list
[
count
%
len
(
prefill_list
)]
global
decode_instances
global
decode_cv
with
decode_cv
:
decode_list
=
list
(
decode_instances
.
items
())
decode_addr
,
decode_zmq_addr
=
decode_list
[
count
%
len
(
decode_list
)]
print
(
f
"handle_request count:
{
count
}
, [HTTP:
{
prefill_addr
}
, "
f
"ZMQ:
{
prefill_zmq_addr
}
] 👉 [HTTP:
{
decode_addr
}
, "
f
"ZMQ:
{
decode_zmq_addr
}
]"
)
count
+=
1
request_id
=
(
f
"___prefill_addr_
{
prefill_zmq_addr
}
___decode_addr_"
f
"
{
decode_zmq_addr
}
_
{
random_uuid
()
}
"
)
# finish prefill
async
for
_
in
forward_request
(
f
"http://
{
prefill_addr
}
/v1/completions"
,
prefill_request
,
request_id
):
continue
# return decode
generator
=
forward_request
(
f
"http://
{
decode_addr
}
/v1/completions"
,
original_request_data
,
request_id
)
response
=
await
make_response
(
generator
)
response
.
timeout
=
None
return
response
except
Exception
as
e
:
import
sys
import
traceback
exc_info
=
sys
.
exc_info
()
print
(
"Error occurred in disagg prefill proxy server"
)
print
(
e
)
print
(
""
.
join
(
traceback
.
format_exception
(
*
exc_info
)))
if
__name__
==
"__main__"
:
t
=
start_service_discovery
(
"0.0.0.0"
,
30001
)
app
.
run
(
host
=
"0.0.0.0"
,
port
=
10001
)
t
.
join
()
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_engine.py
View file @
9eff7ac1
...
...
@@ -743,7 +743,7 @@ class DuSwiftEngine:
"pd_pair_id"
:
remote_address
.
pd_pair_id
,
"comm_rank"
:
rank
}
logger
.
info
(
f
"""_send_sync_new:
{
data
}
"""
)
#
logger.info(f"""_send_sync_new:{data}""")
sock
.
send
(
msgpack
.
dumps
(
data
))
response
=
sock
.
recv
()
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
9eff7ac1
...
...
@@ -3597,6 +3597,7 @@ class GPUModelRunner(
# Run the model.
# Use persistent buffers for CUDA graphs.
clear_kv_metadata
=
self
.
speculative_config
is
None
with
(
set_forward_context
(
attn_metadata
,
...
...
@@ -3610,7 +3611,9 @@ class GPUModelRunner(
skip_compiled
=
has_encoder_input
,
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
self
.
maybe_get_kv_connector_output
(
scheduler_output
,
clear_metadata
=
clear_kv_metadata
)
as
kv_connector_output
,
):
model_output
=
self
.
_model_forward
(
input_ids
=
input_ids
,
...
...
@@ -3826,6 +3829,9 @@ class GPUModelRunner(
# ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids
(
valid_sampled_token_ids
)
if
self
.
speculative_config
is
not
None
:
self
.
clear_kv_connector_metadata
()
with
record_function_or_nullcontext
(
"gpu_model_runner: eplb"
):
self
.
eplb_step
()
...
...
vllm/v1/worker/kv_connector_model_runner_mixin.py
View file @
9eff7ac1
...
...
@@ -67,9 +67,12 @@ class KVConnectorModelRunnerMixin:
@
staticmethod
def
maybe_get_kv_connector_output
(
scheduler_output
:
"SchedulerOutput"
,
clear_metadata
:
bool
=
True
,
)
->
AbstractContextManager
[
KVConnectorOutput
|
None
]:
return
(
KVConnectorModelRunnerMixin
.
_get_kv_connector_output
(
scheduler_output
)
KVConnectorModelRunnerMixin
.
_get_kv_connector_output
(
scheduler_output
,
clear_metadata
=
clear_metadata
)
if
has_kv_transfer_group
()
else
nullcontext
()
)
...
...
@@ -79,7 +82,9 @@ class KVConnectorModelRunnerMixin:
@
staticmethod
@
contextmanager
def
_get_kv_connector_output
(
scheduler_output
:
"SchedulerOutput"
,
wait_for_save
:
bool
=
True
scheduler_output
:
"SchedulerOutput"
,
wait_for_save
:
bool
=
True
,
clear_metadata
:
bool
=
True
,
)
->
Generator
[
KVConnectorOutput
,
None
,
None
]:
output
=
KVConnectorOutput
()
...
...
@@ -107,7 +112,15 @@ class KVConnectorModelRunnerMixin:
output
.
kv_connector_stats
=
kv_connector
.
get_kv_connector_stats
()
output
.
kv_cache_events
=
kv_connector
.
get_kv_connector_kv_cache_events
()
if
clear_metadata
:
kv_connector
.
clear_connector_metadata
()
@
staticmethod
def
clear_kv_connector_metadata
()
->
None
:
"""Clear the KV connector metadata. Call after draft model runs."""
if
has_kv_transfer_group
():
kv_connector
=
get_kv_transfer_group
()
kv_connector
.
clear_connector_metadata
()
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment