sglang / Commits / Commit f0f84975 (unverified)

Authored Jun 04, 2025 by ishandhanani; committed by GitHub on Jun 04, 2025.
feat: add dp-rank to KV events (#6852)
Parent: 3f1e4339

Showing 3 changed files, with 195 additions and 8 deletions (+195, −8):
- python/sglang/srt/disaggregation/kv_events.py (+60, −5)
- python/sglang/srt/managers/scheduler.py (+4, −2)
- test/srt/test_kv_events.py (+131, −1)
File: python/sglang/srt/disaggregation/kv_events.py
```diff
@@ -43,6 +43,7 @@ class EventBatch(
 ):
     ts: float
     events: list[Any]
+    attn_dp_rank: Optional[int] = None
 
 
 class KVCacheEvent(
```
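The new field defaults to `None`, so existing producers that never set it keep working; `ZmqEventPublisher.publish` (below) stamps the publisher's own rank onto unstamped batches. A minimal sketch of that default, assuming `KVEventBatch` accepts keyword construction like its base class:

```python
# Hedged sketch: a batch constructed without a rank stays untagged
# until a publisher stamps it at publish time.
import time
from sglang.srt.disaggregation.kv_events import KVEventBatch

batch = KVEventBatch(ts=time.time(), events=[])
assert batch.attn_dp_rank is None  # default introduced by this commit
```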
```diff
@@ -76,7 +77,21 @@ class KVEventBatch(EventBatch):
 
 
 class EventPublisher(ABC):
-    """Lightweight publisher for EventBatch batches."""
+    """
+    Lightweight publisher for EventBatch batches with
+    support for DP attention.
+
+    In DP attention, each rank has its own Scheduler and KV
+    cache instance, so each rank also gets its own publisher;
+    this avoids duplicate events and ensures proper event
+    attribution. In our implementation:
+    - Each DP rank has its own EventPublisher.
+    - Publishers annotate events with the DP rank.
+    - Consumers can then distinguish events from different DP ranks.
+    """
+
+    def __init__(self, attn_dp_rank: int = 0):
+        self._attn_dp_rank = attn_dp_rank
 
     @abstractmethod
     def publish(self, events: EventBatch) -> None:
```
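To illustrate the attribution contract (a sketch, not part of the commit): once batches carry `attn_dp_rank`, a consumer can key per-rank bookkeeping on it. The `decode_batch` helper here is a hypothetical stand-in for whatever deserializer the consumer uses (the tests use a msgspec-style `Decoder`):

```python
# Hedged sketch: routing decoded batches by DP rank on the consumer side.
from collections import defaultdict

events_by_rank = defaultdict(list)

def on_payload(payload: bytes) -> None:
    batch = decode_batch(payload)  # hypothetical deserializer stand-in
    rank = batch.attn_dp_rank if batch.attn_dp_rank is not None else 0
    events_by_rank[rank].extend(batch.events)
```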
```diff
@@ -130,6 +145,7 @@ class ZmqEventPublisher(EventPublisher):
     def __init__(
         self,
+        attn_dp_rank: int,
         endpoint: str = "tcp://*:5557",
         replay_endpoint: Optional[str] = None,
         buffer_steps: int = 10_000,
@@ -138,6 +154,7 @@ class ZmqEventPublisher(EventPublisher):
         topic: str = "",
     ) -> None:
         # Storage
+        super().__init__(attn_dp_rank)
         self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
         self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
```
```diff
@@ -145,8 +162,11 @@ class ZmqEventPublisher(EventPublisher):
         self._ctx = zmq.Context.instance()
         self._pub: Optional[zmq.Socket] = None
         self._replay: Optional[zmq.Socket] = None
-        self._endpoint = endpoint
-        self._replay_endpoint = replay_endpoint
+        self._dp_rank = attn_dp_rank
+        self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
+        self._replay_endpoint = self.offset_endpoint_port(
+            replay_endpoint, self._dp_rank
+        )
         self._hwm = hwm
         self._socket_setup()
```
```diff
@@ -168,6 +188,8 @@ class ZmqEventPublisher(EventPublisher):
     def publish(self, events: EventBatch) -> None:
         if not self._running:
             raise RuntimeError("Publisher is closed")
+        if events.attn_dp_rank is None:
+            events.attn_dp_rank = self._dp_rank
         self._event_queue.put(events)
 
     def shutdown(self) -> None:
```
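The stamping is non-destructive: a batch that already carries a rank keeps it. A minimal sketch of that contract (constructor arguments besides `attn_dp_rank` are the defaults shown above, so rank 1 binds port 5558):

```python
# Hedged sketch of the stamping semantics in publish().
from sglang.srt.disaggregation.kv_events import KVEventBatch, ZmqEventPublisher

pub = ZmqEventPublisher(attn_dp_rank=1)

untagged = KVEventBatch(ts=0.0, events=[])
pub.publish(untagged)
assert untagged.attn_dp_rank == 1  # stamped with the publisher's rank

pretagged = KVEventBatch(ts=0.0, events=[], attn_dp_rank=0)
pub.publish(pretagged)
assert pretagged.attn_dp_rank == 0  # existing tag preserved
```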
```diff
@@ -288,6 +310,39 @@ class ZmqEventPublisher(EventPublisher):
             # receiving payload is (-1, b"")
             self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
 
+    @staticmethod
+    def offset_endpoint_port(
+        endpoint: Optional[str], data_parallel_rank: int
+    ) -> Optional[str]:
+        """Helper function to offset the port in an endpoint by
+        the data parallel rank.
+
+        Args:
+            endpoint: The endpoint string
+                (e.g., "tcp://*:5557" or "inproc://cache")
+            data_parallel_rank: The data parallel rank to offset by
+
+        Returns:
+            The endpoint with the port offset by data_parallel_rank
+            or suffix appended
+        """
+        # Do nothing if input is None or data_parallel_rank is 0
+        if not endpoint or data_parallel_rank == 0:
+            return endpoint
+
+        if "inproc" in endpoint:
+            return f"{endpoint}_dp{data_parallel_rank}"
+
+        if "tcp" in endpoint:
+            if endpoint and ":" in endpoint:
+                # Get everything after the last colon (the port)
+                last_colon_idx = endpoint.rfind(":")
+                base_addr = endpoint[:last_colon_idx]
+                base_port = int(endpoint[last_colon_idx + 1 :])
+                new_port = base_port + data_parallel_rank
+                return f"{base_addr}:{new_port}"
+            return endpoint
+
+        raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
+
 
 class KVEventsConfig(BaseModel):
     """Configuration for KV event publishing."""
```
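The effect, given the code above: rank 0 leaves any endpoint untouched, TCP endpoints get their port shifted by the rank, and inproc endpoints get a rank suffix. For instance:

```python
# Illustrative calls against the static helper defined above.
from sglang.srt.disaggregation.kv_events import ZmqEventPublisher

offset = ZmqEventPublisher.offset_endpoint_port
assert offset("tcp://*:5557", 0) == "tcp://*:5557"          # rank 0: unchanged
assert offset("tcp://*:5557", 1) == "tcp://*:5558"          # port + rank
assert offset("inproc://cache", 2) == "inproc://cache_dp2"  # suffix for inproc
assert offset(None, 1) is None                              # no endpoint, no-op
```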
```diff
@@ -342,7 +397,7 @@ class EventPublisherFactory:
         cls._registry[name] = ctor
 
     @classmethod
-    def create(cls, config: Optional[str]) -> EventPublisher:
+    def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
         """Create publisher from a config mapping."""
         if not config:
             return NullEventPublisher()
```
```diff
@@ -354,4 +409,4 @@ class EventPublisherFactory:
             constructor = cls._registry[kind]
         except KeyError as exc:
             raise ValueError(f"Unknown event publisher '{kind}'") from exc
-        return constructor(**config_dict)
+        return constructor(attn_dp_rank=attn_dp_rank, **config_dict)
```
File: python/sglang/srt/managers/scheduler.py
```diff
@@ -571,7 +571,9 @@ class Scheduler(
     def init_kv_events(self, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
-            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
+            self.kv_event_publisher = EventPublisherFactory.create(
+                kv_events_config, self.attn_dp_rank
+            )
 
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
```
```diff
@@ -1988,7 +1990,7 @@ class Scheduler(
         self.cum_spec_accept_length = self.cum_spec_accept_count = 0
         for k, v in server_args_dict.items():
             global_server_args_dict[k] = v
-        logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
+        logger.info(f"Global server args updated! {global_server_args_dict=}")
         return SetInternalStateReqOutput(
             updated=True,
             server_args=global_server_args_dict,
```
File: test/srt/test_kv_events.py
```diff
@@ -48,6 +48,9 @@ class TestKvEvents(CustomTestCase):
                 32,
                 "--cuda-graph-max-bs",
                 2,
+                "--enable-dp-attention",
+                "--dp-size",
+                1,
             ],
         )
```
```diff
@@ -233,7 +236,6 @@ class TestKvEvents(CustomTestCase):
                     _, seq_bytes, payload = sub.recv_multipart()
                     event_batch = decoder.decode(payload)
                     for event in event_batch.events:
                         print(f"  - {event}")
                         events.append(event)
 
             for expected in expected_events:
```
```diff
@@ -242,6 +244,134 @@ class TestKvEvents(CustomTestCase):
         finally:
             kill_process_tree(process.pid)
 
+    def test_kv_events_attn_dp(self):
+        """Test that kv events are properly tagged with DP rank in attention DP mode"""
+        # Launch multiple subscribers for different DP ranks
+        decoder = Decoder(type=KVEventBatch)
+        context = zmq.Context()
+
+        # Subscribe to both DP rank endpoints
+        sub_dp0 = context.socket(zmq.SUB)
+        sub_dp0.connect("tcp://localhost:5557")  # DP rank 0
+        topic = "kv-events"
+        sub_dp0.setsockopt_string(zmq.SUBSCRIBE, topic)
+
+        sub_dp1 = context.socket(zmq.SUB)
+        sub_dp1.connect("tcp://localhost:5558")  # DP rank 1 (offset by rank)
+        sub_dp1.setsockopt_string(zmq.SUBSCRIBE, topic)
+
+        # Launch sglang server with DP attention enabled
+        process = popen_launch_server(
+            "silence09/DeepSeek-R1-Small-2layers",
+            DEFAULT_URL_FOR_TEST,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--kv-events-config",
+                '{"publisher": "zmq", "topic": "kv-events"}',
+                "--max-total-tokens",
+                64,
+                "--cuda-graph-max-bs",
+                4,
+                "--enable-dp-attention",
+                "--dp-size",
+                2,
+                "--tp-size",
+                2,
+            ],
+        )
+        try:
+            # Make requests to generate events
+            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
+            self.assertEqual(response.status_code, 200)
+
+            # Send multiple requests to trigger events from both DP ranks
+            for i in range(4):
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/generate",
+                    json={
+                        "text": f"Request {i}: The capital of country {i} is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 16,
+                        },
+                    },
+                )
+
+            # Collect events from both DP ranks
+            events_dp0 = []
+            events_dp1 = []
+            start = time.time()
+            max_wait_s = 10
+            min_events_per_rank = 3  # Expect at least a few events from each rank
+
+            while (time.time() - start) < max_wait_s and (
+                len(events_dp0) < min_events_per_rank
+                or len(events_dp1) < min_events_per_rank
+            ):
+                # Check DP rank 0
+                if sub_dp0.poll(timeout=100):  # 100ms timeout
+                    _, seq_bytes, payload = sub_dp0.recv_multipart()
+                    event_batch = decoder.decode(payload)
+                    print(
+                        f"DP Rank 0 - EventBatch: ts={event_batch.ts}, "
+                        f"attn_dp_rank={event_batch.attn_dp_rank}"
+                    )
+                    self.assertEqual(
+                        event_batch.attn_dp_rank,
+                        0,
+                        "DP rank 0 events should have attn_dp_rank=0",
+                    )
+                    for event in event_batch.events:
+                        print(f"  DP0 - {event}")
+                        events_dp0.append(event)
+
+                # Check DP rank 1
+                if sub_dp1.poll(timeout=100):  # 100ms timeout
+                    _, seq_bytes, payload = sub_dp1.recv_multipart()
+                    event_batch = decoder.decode(payload)
+                    print(
+                        f"DP Rank 1 - EventBatch: ts={event_batch.ts}, "
+                        f"attn_dp_rank={event_batch.attn_dp_rank}"
+                    )
+                    self.assertEqual(
+                        event_batch.attn_dp_rank,
+                        1,
+                        "DP rank 1 events should have attn_dp_rank=1",
+                    )
+                    for event in event_batch.events:
+                        print(f"  DP1 - {event}")
+                        events_dp1.append(event)
+
+            # Verify we got events from both DP ranks
+            print(f"Collected {len(events_dp0)} events from DP rank 0")
+            print(f"Collected {len(events_dp1)} events from DP rank 1")
+
+            self.assertGreaterEqual(
+                len(events_dp0),
+                min_events_per_rank,
+                f"Expected at least {min_events_per_rank} events from DP rank 0",
+            )
+            self.assertGreaterEqual(
+                len(events_dp1),
+                min_events_per_rank,
+                f"Expected at least {min_events_per_rank} events from DP rank 1",
+            )
+
+            # Verify event types are as expected
+            for events in [events_dp0, events_dp1]:
+                for event in events:
+                    self.assertIsInstance(
+                        event,
+                        (BlockStored, BlockRemoved, AllBlocksCleared),
+                        f"Event should be a KV cache event, got {type(event)}",
+                    )
+        finally:
+            sub_dp0.close()
+            sub_dp1.close()
+            context.term()
+            kill_process_tree(process.pid)
 
 if __name__ == "__main__":
     unittest.main()
```
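For watching these events outside the test harness, a minimal standalone subscriber might look like the sketch below. It assumes the server runs with the zmq publisher on the default endpoint, that the rank-r endpoint is port 5557 + r, and that payloads decode with a msgspec `Decoder` as in the test above.

```python
# Hedged sketch: tail KV events for one attention-DP rank.
import sys

import zmq
from msgspec.msgpack import Decoder  # assumed codec, mirroring the test's Decoder

from sglang.srt.disaggregation.kv_events import KVEventBatch

rank = int(sys.argv[1]) if len(sys.argv) > 1 else 0
decoder = Decoder(type=KVEventBatch)

ctx = zmq.Context.instance()
sub = ctx.socket(zmq.SUB)
sub.connect(f"tcp://localhost:{5557 + rank}")  # per-rank port offset
sub.setsockopt_string(zmq.SUBSCRIBE, "kv-events")

while True:
    _topic, _seq, payload = sub.recv_multipart()  # same 3-part frame as the test
    batch = decoder.decode(payload)
    print(f"rank={batch.attn_dp_rank} ts={batch.ts} events={len(batch.events)}")
```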