Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b712be98
Unverified
Commit
b712be98
authored
Jun 03, 2025
by
Yan Ru Pei
Committed by
GitHub
Jun 03, 2025
Browse files
feat: add data parallel rank to KVEventBatch (#18925)
parent
a8da78ea
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
362 additions
and
86 deletions
+362
-86
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
tests/distributed/conftest.py
tests/distributed/conftest.py
+66
-41
tests/distributed/test_events.py
tests/distributed/test_events.py
+67
-2
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+156
-33
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+68
-9
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+3
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
b712be98
...
@@ -145,6 +145,7 @@ steps:
...
@@ -145,6 +145,7 @@ steps:
-
examples/offline_inference/rlhf_colocate.py
-
examples/offline_inference/rlhf_colocate.py
-
tests/examples/offline_inference/data_parallel.py
-
tests/examples/offline_inference/data_parallel.py
-
tests/v1/test_async_llm_dp.py
-
tests/v1/test_async_llm_dp.py
-
tests/v1/engine/test_engine_core_client.py
commands
:
commands
:
# test with tp=2 and external_dp=2
# test with tp=2 and external_dp=2
-
VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
-
VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
...
@@ -154,6 +155,7 @@ steps:
...
@@ -154,6 +155,7 @@ steps:
# test with internal dp
# test with internal dp
-
python3 ../examples/offline_inference/data_parallel.py
-
python3 ../examples/offline_inference/data_parallel.py
-
TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s distributed/test_pynccl.py
-
pytest -v -s distributed/test_pynccl.py
...
...
tests/distributed/conftest.py
View file @
b712be98
...
@@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
...
@@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
from
.test_events
import
SampleBatch
from
.test_events
import
SampleBatch
DP_RANK
=
0
@
pytest
.
fixture
@
pytest
.
fixture
def
random_port
():
def
random_port
():
"""Generate a random port number for testing"""
"""Generate a random port number for testing"""
return
random
.
randint
(
10000
,
600
00
)
return
random
.
randint
(
10000
,
599
00
)
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -30,21 +32,23 @@ def publisher_config(random_port, request):
...
@@ -30,21 +32,23 @@ def publisher_config(random_port, request):
replay_endpoint
=
endpoint
+
"-replay"
replay_endpoint
=
endpoint
+
"-replay"
else
:
else
:
endpoint
=
f
"tcp://*:
{
random_port
}
"
endpoint
=
f
"tcp://*:
{
random_port
}
"
replay_endpoint
=
f
"tcp://*:
{
random_port
+
1
}
"
replay_endpoint
=
f
"tcp://*:
{
random_port
+
1
00
}
"
return
KVEventsConfig
(
enable_kv_cache_events
=
True
,
return
KVEventsConfig
(
publisher
=
"zmq"
,
enable_kv_cache_events
=
True
,
endpoint
=
endpoint
,
publisher
=
"zmq"
,
replay_endpoint
=
replay_endpoint
,
endpoint
=
endpoint
,
buffer_steps
=
100
,
replay_endpoint
=
replay_endpoint
,
hwm
=
1000
,
buffer_steps
=
100
,
topic
=
"test"
)
hwm
=
1000
,
topic
=
"test"
,
)
@
pytest
.
fixture
@
pytest
.
fixture
def
publisher
(
publisher_config
):
def
publisher
(
publisher_config
):
"""Create and return a publisher instance"""
"""Create and return a publisher instance"""
pub
=
EventPublisherFactory
.
create
(
publisher_config
)
pub
=
EventPublisherFactory
.
create
(
publisher_config
,
DP_RANK
)
yield
pub
yield
pub
pub
.
shutdown
()
pub
.
shutdown
()
...
@@ -60,7 +64,11 @@ def subscriber(publisher_config):
...
@@ -60,7 +64,11 @@ def subscriber(publisher_config):
if
replay_endpoint
and
replay_endpoint
.
startswith
(
"tcp://*"
):
if
replay_endpoint
and
replay_endpoint
.
startswith
(
"tcp://*"
):
replay_endpoint
=
replay_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
replay_endpoint
=
replay_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
sub
=
MockSubscriber
(
endpoint
,
replay_endpoint
,
publisher_config
.
topic
)
sub
=
MockSubscriber
(
[
endpoint
],
[
replay_endpoint
]
if
replay_endpoint
else
None
,
publisher_config
.
topic
,
)
yield
sub
yield
sub
sub
.
close
()
sub
.
close
()
...
@@ -68,26 +76,37 @@ def subscriber(publisher_config):
...
@@ -68,26 +76,37 @@ def subscriber(publisher_config):
class
MockSubscriber
:
class
MockSubscriber
:
"""Helper class to receive and verify published events"""
"""Helper class to receive and verify published events"""
def
__init__
(
self
,
def
__init__
(
pub_endpoint
:
str
,
self
,
replay_endpoint
:
Optional
[
str
]
=
None
,
pub_endpoints
:
Union
[
str
,
list
[
str
]],
topic
:
str
=
""
,
replay_endpoints
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
decode_type
=
SampleBatch
):
topic
:
str
=
""
,
decode_type
=
SampleBatch
,
):
self
.
ctx
=
zmq
.
Context
.
instance
()
self
.
ctx
=
zmq
.
Context
.
instance
()
# Set up subscriber socket
# Convert single endpoint to list for consistency
self
.
sub
=
self
.
ctx
.
socket
(
zmq
.
SUB
)
if
isinstance
(
pub_endpoints
,
str
):
self
.
sub
.
setsockopt
(
zmq
.
SUBSCRIBE
,
topic
.
encode
(
'utf-8'
))
pub_endpoints
=
[
pub_endpoints
]
self
.
sub
.
connect
(
pub_endpoint
)
if
isinstance
(
replay_endpoints
,
str
):
replay_endpoints
=
[
replay_endpoints
]
# Set up replay socket if provided
# Set up subscriber socket - connect to all endpoints
self
.
replay
=
None
self
.
sub
=
self
.
ctx
.
socket
(
zmq
.
SUB
)
if
replay_endpoint
:
self
.
sub
.
setsockopt
(
zmq
.
SUBSCRIBE
,
topic
.
encode
(
"utf-8"
))
self
.
replay
=
self
.
ctx
.
socket
(
zmq
.
REQ
)
for
endpoint
in
pub_endpoints
:
self
.
replay
.
connect
(
replay_endpoint
)
self
.
sub
.
connect
(
endpoint
)
# Set up replay sockets if provided
self
.
replay_sockets
=
[]
if
replay_endpoints
:
for
replay_endpoint
in
replay_endpoints
:
replay
=
self
.
ctx
.
socket
(
zmq
.
REQ
)
replay
.
connect
(
replay_endpoint
)
self
.
replay_sockets
.
append
(
replay
)
self
.
topic
=
topic
self
.
topic
=
topic
self
.
topic_bytes
=
topic
.
encode
(
'
utf-8
'
)
self
.
topic_bytes
=
topic
.
encode
(
"
utf-8
"
)
self
.
received_msgs
:
list
[
tuple
[
int
,
SampleBatch
]]
=
[]
self
.
received_msgs
:
list
[
tuple
[
int
,
SampleBatch
]]
=
[]
self
.
last_seq
=
-
1
self
.
last_seq
=
-
1
self
.
decoder
=
msgspec
.
msgpack
.
Decoder
(
type
=
decode_type
)
self
.
decoder
=
msgspec
.
msgpack
.
Decoder
(
type
=
decode_type
)
...
@@ -107,25 +126,31 @@ class MockSubscriber:
...
@@ -107,25 +126,31 @@ class MockSubscriber:
self
.
received_msgs
.
append
((
seq
,
data
))
self
.
received_msgs
.
append
((
seq
,
data
))
return
seq
,
data
return
seq
,
data
def
request_replay
(
self
,
start_seq
:
int
)
->
None
:
def
request_replay
(
self
,
start_seq
:
int
,
socket_idx
:
int
=
0
)
->
None
:
"""Request replay of messages starting from start_seq"""
"""Request replay of messages starting from start_seq"""
if
not
self
.
replay
:
if
not
self
.
replay_sockets
:
raise
ValueError
(
"Replay socket not initialized"
)
raise
ValueError
(
"Replay sockets not initialized"
)
if
socket_idx
>=
len
(
self
.
replay_sockets
):
self
.
replay
.
send
(
start_seq
.
to_bytes
(
8
,
"big"
))
raise
ValueError
(
f
"Invalid socket index
{
socket_idx
}
"
)
def
receive_replay
(
self
)
->
list
[
tuple
[
int
,
SampleBatch
]]:
self
.
replay_sockets
[
socket_idx
].
send
(
start_seq
.
to_bytes
(
8
,
"big"
))
"""Receive replayed messages"""
if
not
self
.
replay
:
def
receive_replay
(
self
,
raise
ValueError
(
"Replay socket not initialized"
)
socket_idx
:
int
=
0
)
->
list
[
tuple
[
int
,
SampleBatch
]]:
"""Receive replayed messages from a specific replay socket"""
if
not
self
.
replay_sockets
:
raise
ValueError
(
"Replay sockets not initialized"
)
if
socket_idx
>=
len
(
self
.
replay_sockets
):
raise
ValueError
(
f
"Invalid socket index
{
socket_idx
}
"
)
replay_socket
=
self
.
replay_sockets
[
socket_idx
]
replayed
:
list
[
tuple
[
int
,
SampleBatch
]]
=
[]
replayed
:
list
[
tuple
[
int
,
SampleBatch
]]
=
[]
while
True
:
while
True
:
try
:
try
:
if
not
self
.
replay
.
poll
(
1000
):
if
not
replay
_socket
.
poll
(
1000
):
break
break
frames
=
self
.
replay
.
recv_multipart
()
frames
=
replay
_socket
.
recv_multipart
()
if
not
frames
or
not
frames
[
-
1
]:
if
not
frames
or
not
frames
[
-
1
]:
# End of replay marker
# End of replay marker
break
break
...
@@ -142,5 +167,5 @@ class MockSubscriber:
...
@@ -142,5 +167,5 @@ class MockSubscriber:
def
close
(
self
):
def
close
(
self
):
"""Clean up resources"""
"""Clean up resources"""
self
.
sub
.
close
()
self
.
sub
.
close
()
i
f
self
.
replay
:
f
or
replay
in
self
.
replay
_sockets
:
self
.
replay
.
close
()
replay
.
close
()
tests/distributed/test_events.py
View file @
b712be98
...
@@ -9,6 +9,8 @@ import pytest
...
@@ -9,6 +9,8 @@ import pytest
from
vllm.distributed.kv_events
import
(
EventBatch
,
EventPublisherFactory
,
from
vllm.distributed.kv_events
import
(
EventBatch
,
EventPublisherFactory
,
NullEventPublisher
)
NullEventPublisher
)
DP_RANK
=
0
class
EventSample
(
class
EventSample
(
msgspec
.
Struct
,
msgspec
.
Struct
,
...
@@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
...
@@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
publisher_config
.
replay_endpoint
=
None
publisher_config
.
replay_endpoint
=
None
publisher_config
.
topic
=
"foo"
publisher_config
.
topic
=
"foo"
pub
=
EventPublisherFactory
.
create
(
publisher_config
)
pub
=
EventPublisherFactory
.
create
(
publisher_config
,
DP_RANK
)
from
.conftest
import
MockSubscriber
from
.conftest
import
MockSubscriber
sub_foo
=
MockSubscriber
(
publisher_config
.
endpoint
,
None
,
"foo"
)
sub_foo
=
MockSubscriber
(
publisher_config
.
endpoint
,
None
,
"foo"
)
...
@@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber):
...
@@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber):
def
test_null_publisher
():
def
test_null_publisher
():
"""Test that NullEventPublisher can be used without errors"""
"""Test that NullEventPublisher can be used without errors"""
publisher
=
NullEventPublisher
()
publisher
=
NullEventPublisher
(
DP_RANK
)
# This should not raise any errors
# This should not raise any errors
batch
=
create_test_events
(
5
)
batch
=
create_test_events
(
5
)
publisher
.
publish
(
batch
)
publisher
.
publish
(
batch
)
publisher
.
shutdown
()
publisher
.
shutdown
()
def
test_data_parallel_rank_tagging
(
publisher_config
):
"""Test that events are properly tagged with their data parallel rank"""
publisher_config
.
topic
=
"foo"
pub_0
=
EventPublisherFactory
.
create
(
publisher_config
,
DP_RANK
)
pub_1
=
EventPublisherFactory
.
create
(
publisher_config
,
DP_RANK
+
1
)
# Hardcode the expected endpoints based on port offsetting behavior
# Both ranks get offsets according to _offset_endpoint_port function
base_endpoint
=
publisher_config
.
endpoint
if
"tcp://"
in
base_endpoint
:
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
expected_endpoint_0
=
base_endpoint
# rank 0 gets port + 0 = same port
expected_endpoint_1
=
base_endpoint
.
replace
(
":5557"
,
":5558"
)
# rank 1 gets port + 1
else
:
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
expected_endpoint_0
=
base_endpoint
# rank 0 gets base
expected_endpoint_1
=
base_endpoint
+
"_dp1"
# rank 1 gets _dp1
from
.conftest
import
MockSubscriber
sub_0
=
MockSubscriber
(
expected_endpoint_0
,
None
,
publisher_config
.
topic
)
sub_1
=
MockSubscriber
(
expected_endpoint_1
,
None
,
publisher_config
.
topic
)
try
:
time
.
sleep
(
0.1
)
# Let publishers start up
# Publish events from different ranks
batch_0
=
create_test_events
(
2
)
batch_1
=
create_test_events
(
3
)
pub_0
.
publish
(
batch_0
)
pub_1
.
publish
(
batch_1
)
# Receive events from rank 0
result_0
=
sub_0
.
receive_one
(
timeout
=
200
)
assert
result_0
is
not
None
,
"No message received from rank 0"
seq_0
,
received_0
=
result_0
# Receive events from rank 1
result_1
=
sub_1
.
receive_one
(
timeout
=
200
)
assert
result_1
is
not
None
,
"No message received from rank 1"
seq_1
,
received_1
=
result_1
# Verify DP rank tagging
assert
received_0
.
data_parallel_rank
==
0
,
(
f
"Expected DP rank 0, got
{
received_0
.
data_parallel_rank
}
"
)
assert
received_1
.
data_parallel_rank
==
1
,
(
f
"Expected DP rank 1, got
{
received_1
.
data_parallel_rank
}
"
)
# Verify event content is correct
assert
len
(
received_0
.
events
)
==
2
,
"Wrong number of events from rank 0"
assert
len
(
received_1
.
events
)
==
3
,
"Wrong number of events from rank 1"
finally
:
pub_0
.
shutdown
()
pub_1
.
shutdown
()
sub_0
.
close
()
sub_1
.
close
()
tests/v1/engine/test_engine_core_client.py
View file @
b712be98
...
@@ -12,8 +12,10 @@ from typing import Optional
...
@@ -12,8 +12,10 @@ from typing import Optional
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
tests.utils
import
multi_gpu_test
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.distributed.kv_events
import
BlockStored
,
KVEventBatch
from
vllm.distributed.kv_events
import
(
BlockStored
,
KVEventBatch
,
ZmqEventPublisher
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
...
@@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS
=
TOKENIZER
(
PROMPT
).
input_ids
PROMPT_TOKENS
=
TOKENIZER
(
PROMPT
).
input_ids
def
make_request
(
params
:
SamplingParams
)
->
EngineCoreRequest
:
def
make_request
(
params
:
SamplingParams
,
prompt_tokens_ids
:
Optional
[
list
[
int
]]
=
None
)
->
EngineCoreRequest
:
if
not
prompt_tokens_ids
:
prompt_tokens_ids
=
PROMPT_TOKENS
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
str
(
uuid
.
uuid4
()),
request_id
=
str
(
uuid
.
uuid4
()),
prompt_token_ids
=
PROMPT_TOKENS
,
prompt_token_ids
=
prompt_tokens_ids
,
mm_inputs
=
None
,
mm_inputs
=
None
,
mm_hashes
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
mm_placeholders
=
None
,
...
@@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
...
@@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
break
break
async
def
loop_until_fully_done_async
(
client
:
EngineCoreClient
,
outputs
:
dict
):
while
True
:
engine_core_outputs
=
(
await
client
.
get_output_async
()).
outputs
if
len
(
engine_core_outputs
)
==
0
:
continue
# Add outputs to the dict
for
out
in
engine_core_outputs
:
outputs
[
out
.
request_id
].
append
(
out
)
# Check if all request IDs in outputs have finished
if
all
(
outs
and
outs
[
-
1
].
finished
for
outs
in
outputs
.
values
()):
break
await
asyncio
.
sleep
(
0.1
)
# Dummy utility function to monkey-patch into engine core.
# Dummy utility function to monkey-patch into engine core.
def
echo
(
self
,
msg
:
str
,
err_msg
:
Optional
[
str
]
=
None
)
->
str
:
def
echo
(
self
,
msg
:
str
,
err_msg
:
Optional
[
str
]
=
None
)
->
str
:
print
(
f
"echo util function called:
{
msg
}
,
{
err_msg
}
"
)
print
(
f
"echo util function called:
{
msg
}
,
{
err_msg
}
"
)
...
@@ -273,10 +299,12 @@ def test_kv_cache_events(
...
@@ -273,10 +299,12 @@ def test_kv_cache_events(
block_size
=
16
block_size
=
16
num_blocks
=
2
num_blocks
=
2
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
,
engine_args
=
EngineArgs
(
enforce_eager
=
True
,
model
=
MODEL_NAME
,
enable_prefix_caching
=
True
,
enforce_eager
=
True
,
block_size
=
block_size
)
enable_prefix_caching
=
True
,
block_size
=
block_size
,
)
engine_args
.
kv_events_config
=
publisher_config
engine_args
.
kv_events_config
=
publisher_config
vllm_config
=
engine_args
.
create_engine_config
(
vllm_config
=
engine_args
.
create_engine_config
(
...
@@ -297,19 +325,8 @@ def test_kv_cache_events(
...
@@ -297,19 +325,8 @@ def test_kv_cache_events(
try
:
try
:
custom_tokens
=
list
(
range
(
num_blocks
*
block_size
))
custom_tokens
=
list
(
range
(
num_blocks
*
block_size
))
request
=
EngineCoreRequest
(
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
request_id
=
str
(
uuid
.
uuid4
()),
request
=
make_request
(
sampling_params
,
custom_tokens
)
prompt_token_ids
=
custom_tokens
,
mm_inputs
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
sampling_params
=
SamplingParams
(
max_tokens
=
1
),
# Short completion for speed
eos_token_id
=
None
,
arrival_time
=
time
.
time
(),
lora_request
=
None
,
cache_salt
=
None
,
)
client
.
add_request
(
request
)
client
.
add_request
(
request
)
outputs
:
dict
[
str
,
list
]
=
{
request
.
request_id
:
[]}
outputs
:
dict
[
str
,
list
]
=
{
request
.
request_id
:
[]}
...
@@ -321,24 +338,130 @@ def test_kv_cache_events(
...
@@ -321,24 +338,130 @@ def test_kv_cache_events(
seq
,
received
=
result
seq
,
received
=
result
assert
seq
==
0
,
"Sequence number mismatch"
assert
seq
==
0
,
"Sequence number mismatch"
assert
len
(
received
.
events
)
==
1
,
(
assert
(
len
(
received
.
events
)
==
1
"We should have exactly one BlockStored event"
)
),
"We should have exactly one BlockStored event"
event
=
received
.
events
[
0
]
event
=
received
.
events
[
0
]
assert
isinstance
(
assert
isinstance
(
event
,
BlockStored
),
(
"We should have a BlockStored event"
)
event
,
BlockStored
),
"We should have a BlockStored event"
assert
len
(
event
.
block_hashes
)
==
num_blocks
,
(
assert
(
len
(
event
.
block_hashes
)
==
num_blocks
"We should have a BlockStored event with 2 block_hashes"
)
),
"We should have a BlockStored event with 2 block_hashes"
assert
event
.
block_size
==
block_size
,
(
assert
(
event
.
block_size
==
block_size
"Block size should be the same as the block size"
)
),
"Block size should be the same as the block size"
assert
event
.
parent_block_hash
is
None
,
(
assert
(
event
.
parent_block_hash
"Parent block hash should be None"
)
is
None
),
"Parent block hash should be None"
assert
event
.
lora_id
is
None
,
"Lora id should be None"
assert
event
.
lora_id
is
None
,
"Lora id should be None"
assert
len
(
event
.
token_ids
)
==
num_blocks
*
block_size
,
(
assert
(
len
(
event
.
token_ids
)
==
num_blocks
*
block_size
"Token ids should be the same as the custom tokens"
)
),
"Token ids should be the same as the custom tokens"
assert
event
.
token_ids
==
custom_tokens
,
(
assert
(
event
.
token_ids
==
custom_tokens
"Token ids should be the same as the custom tokens"
)
),
"Token ids should be the same as the custom tokens"
finally
:
client
.
shutdown
()
subscriber
.
close
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"multiprocessing_mode,publisher_config"
,
[(
True
,
"tcp"
)],
indirect
=
[
"publisher_config"
],
)
@
multi_gpu_test
(
num_gpus
=
4
)
async
def
test_kv_cache_events_dp
(
monkeypatch
:
pytest
.
MonkeyPatch
,
multiprocessing_mode
:
bool
,
publisher_config
,
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
block_size
=
16
num_blocks
=
2
dp_size
=
2
tp_size
=
2
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
,
enforce_eager
=
True
,
enable_prefix_caching
=
True
,
data_parallel_size
=
dp_size
,
tensor_parallel_size
=
tp_size
,
block_size
=
block_size
,
)
engine_args
.
kv_events_config
=
publisher_config
vllm_config
=
engine_args
.
create_engine_config
(
UsageContext
.
UNKNOWN_CONTEXT
)
executor_class
=
Executor
.
get_class
(
vllm_config
)
client
=
EngineCoreClient
.
make_client
(
multiprocess_mode
=
multiprocessing_mode
,
asyncio_mode
=
True
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
False
,
)
await
asyncio
.
sleep
(
1
)
# Build endpoints for all DP ranks
base_endpoint
=
publisher_config
.
endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
endpoints
=
[]
for
i
in
range
(
dp_size
):
offset_endpoint
=
ZmqEventPublisher
.
offset_endpoint_port
(
base_endpoint
,
i
)
endpoints
.
append
(
offset_endpoint
)
subscriber
=
MockSubscriber
(
endpoints
,
topic
=
publisher_config
.
topic
,
decode_type
=
KVEventBatch
)
try
:
custom_tokens
=
list
(
range
(
num_blocks
*
block_size
))
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
all_request_ids
=
[]
# Create and add 25 requests
# NOTE: attempts to force routing to both dp groups but can be flaky
for
i
in
range
(
25
):
await
asyncio
.
sleep
(
0.01
)
request
=
make_request
(
sampling_params
,
custom_tokens
)
await
client
.
add_request_async
(
request
)
all_request_ids
.
append
(
request
.
request_id
)
await
asyncio
.
sleep
(
0.1
)
# Initialize outputs dict for all requests
outputs
:
dict
[
str
,
list
]
=
{
req_id
:
[]
for
req_id
in
all_request_ids
}
print
(
"processing requests..."
)
await
asyncio
.
wait_for
(
loop_until_fully_done_async
(
client
,
outputs
),
timeout
=
20.0
)
# Receive from subscriber until no more messages
print
(
"collecting results..."
)
results
=
[]
while
True
:
result
=
subscriber
.
receive_one
(
timeout
=
1
)
print
(
result
)
if
result
is
None
:
break
results
.
append
(
result
)
# Collect all events and data_parallel_ranks from all results
all_dp_ranks
=
[
received
.
data_parallel_rank
for
(
_
,
received
)
in
results
]
unique_dps
=
set
(
all_dp_ranks
)
assert
(
len
(
unique_dps
)
==
2
),
f
"Expected 2 unique data_parallel_ranks, got
{
len
(
unique_dps
)
}
"
finally
:
finally
:
client
.
shutdown
()
client
.
shutdown
()
subscriber
.
close
()
@
pytest
.
mark
.
timeout
(
20
)
@
pytest
.
mark
.
timeout
(
20
)
...
...
vllm/distributed/kv_events.py
View file @
b712be98
...
@@ -28,6 +28,7 @@ class EventBatch(
...
@@ -28,6 +28,7 @@ class EventBatch(
):
):
ts
:
float
ts
:
float
events
:
list
[
Any
]
events
:
list
[
Any
]
data_parallel_rank
:
Optional
[
int
]
=
None
class
KVCacheEvent
(
class
KVCacheEvent
(
...
@@ -60,7 +61,22 @@ class KVEventBatch(EventBatch):
...
@@ -60,7 +61,22 @@ class KVEventBatch(EventBatch):
class
EventPublisher
(
ABC
):
class
EventPublisher
(
ABC
):
"""Lightweight publisher for EventBatch batches."""
"""Lightweight publisher for EventBatch batches with data parallelism
support.
In data parallel setups, each DP rank runs its own EventPublisher instance
to avoid duplicate events and ensure proper event attribution:
- Each DP rank creates a separate publisher
- Publishers automatically annotate events with their data_parallel_rank
- This allows consumers to distinguish events from different DP ranks
The publisher is responsible for adding DP metadata since the scheduler
operates independently of DP topology and shouldn't need DP awareness.
"""
def
__init__
(
self
,
data_parallel_rank
:
int
=
0
)
->
None
:
self
.
_data_parallel_rank
=
data_parallel_rank
@
abstractmethod
@
abstractmethod
def
publish
(
self
,
events
:
EventBatch
)
->
None
:
def
publish
(
self
,
events
:
EventBatch
)
->
None
:
...
@@ -113,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
...
@@ -113,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
def
__init__
(
def
__init__
(
self
,
self
,
data_parallel_rank
:
int
,
endpoint
:
str
=
"tcp://*:5557"
,
endpoint
:
str
=
"tcp://*:5557"
,
replay_endpoint
:
Optional
[
str
]
=
None
,
replay_endpoint
:
Optional
[
str
]
=
None
,
buffer_steps
:
int
=
10_000
,
buffer_steps
:
int
=
10_000
,
...
@@ -121,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
...
@@ -121,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
topic
:
str
=
""
,
topic
:
str
=
""
,
)
->
None
:
)
->
None
:
# Storage
# Storage
super
().
__init__
(
data_parallel_rank
)
self
.
_event_queue
=
Queue
[
Optional
[
EventBatch
]](
maxsize
=
max_queue_size
)
self
.
_event_queue
=
Queue
[
Optional
[
EventBatch
]](
maxsize
=
max_queue_size
)
self
.
_buffer
=
deque
[
tuple
[
int
,
bytes
]](
maxlen
=
buffer_steps
)
self
.
_buffer
=
deque
[
tuple
[
int
,
bytes
]](
maxlen
=
buffer_steps
)
...
@@ -128,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
...
@@ -128,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
self
.
_ctx
=
zmq
.
Context
.
instance
()
self
.
_ctx
=
zmq
.
Context
.
instance
()
self
.
_pub
:
Optional
[
zmq
.
Socket
]
=
None
self
.
_pub
:
Optional
[
zmq
.
Socket
]
=
None
self
.
_replay
:
Optional
[
zmq
.
Socket
]
=
None
self
.
_replay
:
Optional
[
zmq
.
Socket
]
=
None
self
.
_endpoint
=
endpoint
self
.
_dp_rank
=
data_parallel_rank
self
.
_replay_endpoint
=
replay_endpoint
self
.
_endpoint
=
self
.
offset_endpoint_port
(
endpoint
,
self
.
_dp_rank
)
self
.
_replay_endpoint
=
self
.
offset_endpoint_port
(
replay_endpoint
,
self
.
_dp_rank
)
self
.
_hwm
=
hwm
self
.
_hwm
=
hwm
self
.
_socket_setup
()
self
.
_socket_setup
()
...
@@ -149,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
...
@@ -149,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
def
publish
(
self
,
events
:
EventBatch
)
->
None
:
def
publish
(
self
,
events
:
EventBatch
)
->
None
:
if
not
self
.
_running
:
if
not
self
.
_running
:
raise
RuntimeError
(
"Publisher is closed"
)
raise
RuntimeError
(
"Publisher is closed"
)
if
events
.
data_parallel_rank
is
None
:
events
.
data_parallel_rank
=
self
.
_data_parallel_rank
self
.
_event_queue
.
put
(
events
)
self
.
_event_queue
.
put
(
events
)
def
shutdown
(
self
)
->
None
:
def
shutdown
(
self
)
->
None
:
...
@@ -191,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
...
@@ -191,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
self
.
_pub
.
set_hwm
(
self
.
_hwm
)
self
.
_pub
.
set_hwm
(
self
.
_hwm
)
# Heuristic: bind if wildcard / * present, else connect.
# Heuristic: bind if wildcard / * present, else connect.
# bind stable, connect volatile convention
# bind stable, connect volatile convention
if
(
"*"
in
self
.
_endpoint
or
"::"
in
self
.
_endpoint
if
(
self
.
_endpoint
is
not
None
or
self
.
_endpoint
.
startswith
(
"ipc://"
)
and
(
"*"
in
self
.
_endpoint
or
"::"
in
self
.
_endpoint
or
self
.
_endpoint
.
startswith
(
"inproc://"
)):
or
self
.
_endpoint
.
startswith
(
"ipc://"
)
or
self
.
_endpoint
.
startswith
(
"inproc://"
))):
self
.
_pub
.
bind
(
self
.
_endpoint
)
self
.
_pub
.
bind
(
self
.
_endpoint
)
el
s
e
:
el
if
self
.
_endpoint
is
not
Non
e
:
self
.
_pub
.
connect
(
self
.
_endpoint
)
self
.
_pub
.
connect
(
self
.
_endpoint
)
# Set up replay socket: use ROUTER
# Set up replay socket: use ROUTER
...
@@ -266,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
...
@@ -266,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
# receiving payload is (-1, b""")
# receiving payload is (-1, b""")
self
.
_replay
.
send_multipart
((
client_id
,
b
""
,
self
.
END_SEQ
,
b
""
))
self
.
_replay
.
send_multipart
((
client_id
,
b
""
,
self
.
END_SEQ
,
b
""
))
@
staticmethod
def
offset_endpoint_port
(
endpoint
:
Optional
[
str
],
data_parallel_rank
:
int
)
->
Optional
[
str
]:
"""Helper function to offset the port in an endpoint by
the data parallel rank.
Args:
endpoint: The endpoint string
(e.g., "tcp://*:5557" or "inproc://cache")
data_parallel_rank: The data parallel rank to offset by
Returns:
The endpoint with the port offset by data_parallel_rank
or suffix appended
"""
# Do nothing if input is None or data_parallel_rank is 0
if
not
endpoint
or
data_parallel_rank
==
0
:
return
endpoint
if
"inproc"
in
endpoint
:
return
f
"
{
endpoint
}
_dp
{
data_parallel_rank
}
"
if
"tcp"
in
endpoint
:
if
endpoint
and
":"
in
endpoint
:
# Get everything after the last colon (the port)
last_colon_idx
=
endpoint
.
rfind
(
":"
)
base_addr
=
endpoint
[:
last_colon_idx
]
base_port
=
int
(
endpoint
[
last_colon_idx
+
1
:])
new_port
=
base_port
+
data_parallel_rank
return
f
"
{
base_addr
}
:
{
new_port
}
"
return
endpoint
raise
ValueError
(
"Invalid endpoint: must contain 'inproc' or 'tcp'"
)
class
EventPublisherFactory
:
class
EventPublisherFactory
:
_registry
:
dict
[
str
,
Callable
[...,
EventPublisher
]]
=
{
_registry
:
dict
[
str
,
Callable
[...,
EventPublisher
]]
=
{
...
@@ -281,7 +337,9 @@ class EventPublisherFactory:
...
@@ -281,7 +337,9 @@ class EventPublisherFactory:
cls
.
_registry
[
name
]
=
ctor
cls
.
_registry
[
name
]
=
ctor
@
classmethod
@
classmethod
def
create
(
cls
,
config
:
Optional
[
KVEventsConfig
])
->
EventPublisher
:
def
create
(
cls
,
config
:
Optional
[
KVEventsConfig
],
data_parallel_rank
:
int
=
0
)
->
EventPublisher
:
"""Create publisher from a config mapping."""
"""Create publisher from a config mapping."""
if
not
config
:
if
not
config
:
return
NullEventPublisher
()
return
NullEventPublisher
()
...
@@ -294,4 +352,5 @@ class EventPublisherFactory:
...
@@ -294,4 +352,5 @@ class EventPublisherFactory:
constructor
=
cls
.
_registry
[
kind
]
constructor
=
cls
.
_registry
[
kind
]
except
KeyError
as
exc
:
except
KeyError
as
exc
:
raise
ValueError
(
f
"Unknown event publisher '
{
kind
}
'"
)
from
exc
raise
ValueError
(
f
"Unknown event publisher '
{
kind
}
'"
)
from
exc
return
constructor
(
**
config_dict
)
return
constructor
(
data_parallel_rank
=
data_parallel_rank
,
**
config_dict
)
vllm/v1/core/sched/scheduler.py
View file @
b712be98
...
@@ -80,7 +80,9 @@ class Scheduler(SchedulerInterface):
...
@@ -80,7 +80,9 @@ class Scheduler(SchedulerInterface):
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
)
self
.
kv_events_config
,
vllm_config
.
parallel_config
.
data_parallel_rank
,
)
num_gpu_blocks
=
self
.
cache_config
.
num_gpu_blocks
num_gpu_blocks
=
self
.
cache_config
.
num_gpu_blocks
assert
num_gpu_blocks
is
not
None
and
num_gpu_blocks
>
0
assert
num_gpu_blocks
is
not
None
and
num_gpu_blocks
>
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment