change / sglang / Commits / 7adf245b

Commit 7adf245b, authored May 19, 2025 by Trevor Morris, committed by GitHub on May 19, 2025.

[Metrics] Add KV events publishing (#6098)

Parent: 299fd22f

Showing 7 changed files with 686 additions and 1 deletion (+686, -1).
  python/pyproject.toml                              +1   -0
  python/sglang/srt/disaggregation/kv_events.py      +357 -0
  python/sglang/srt/managers/scheduler.py            +18  -1
  python/sglang/srt/mem_cache/base_prefix_cache.py   +3   -0
  python/sglang/srt/mem_cache/radix_cache.py         +53  -0
  python/sglang/srt/server_args.py                   +7   -0
  test/srt/test_kv_events.py                         +247 -0
python/pyproject.toml

...
@@ -25,6 +25,7 @@ runtime_common = [
     "interegular",
     "llguidance>=0.7.11,<0.8.0",
     "modelscope",
+    "msgspec",
     "ninja",
     "orjson",
     "packaging",
...
python/sglang/srt/disaggregation/kv_events.py (new file, 0 → 100644)
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
KV caching events
"""
import
atexit
import
logging
import
queue
import
threading
import
time
from
abc
import
ABC
,
abstractmethod
from
collections
import
deque
from
itertools
import
count
from
queue
import
Queue
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
msgspec
import
zmq
from
pydantic
import
BaseModel
logger
=
logging
.
getLogger
(
__name__
)
class
EventBatch
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
,
# type: ignore[call-arg]
):
ts
:
float
events
:
list
[
Any
]
class
KVCacheEvent
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
,
# type: ignore[call-arg]
tag
=
True
,
):
"""Base class for all KV cache-related events"""
class
BlockStored
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
parent_block_hash
:
Optional
[
int
]
token_ids
:
list
[
int
]
block_size
:
int
lora_id
:
Optional
[
int
]
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
class
AllBlocksCleared
(
KVCacheEvent
):
pass
class
KVEventBatch
(
EventBatch
):
events
:
list
[
Union
[
BlockStored
,
BlockRemoved
,
AllBlocksCleared
]]
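
A quick round-trip sketch (not part of the commit; values are illustrative) of how this tagged union serializes: ``tag=True`` on KVCacheEvent makes each event self-describing on the wire, so a decoder typed to KVEventBatch recovers the concrete subclasses.

    # Round-trip a KVEventBatch through msgpack; values are illustrative.
    import time
    import msgspec

    batch = KVEventBatch(
        ts=time.time(),
        events=[
            BlockStored(
                block_hashes=[123],
                parent_block_hash=None,
                token_ids=[1, 2, 3],
                block_size=3,
                lora_id=None,
            ),
            AllBlocksCleared(),
        ],
    )
    payload = msgspec.msgpack.Encoder().encode(batch)
    decoded = msgspec.msgpack.Decoder(type=KVEventBatch).decode(payload)
    assert isinstance(decoded.events[0], BlockStored)  # tag resolved the subclass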
class EventPublisher(ABC):
    """Lightweight publisher for EventBatch batches."""

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit events in order.

        Implementations should guarantee at-least-once delivery and
        monotonic ordering (e.g., via sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the publisher."""


class NullEventPublisher(EventPublisher):
    """No-op implementation (default when disabled)."""

    def publish(self, events) -> None:
        return

    def shutdown(self) -> None:
        return
class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """

    SHUTDOWN_TIMEOUT: float = 1.0
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None

        self._endpoint = endpoint
        self._replay_endpoint = replay_endpoint
        self._hwm = hwm
        self._socket_setup()

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")
        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()
        atexit.register(self.shutdown)

    def publish(self, events: EventBatch) -> None:
        if not self._running:
            raise RuntimeError("Publisher is closed")
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources."""
        self._running = False
        self._event_queue.put_nowait(None)

        start = time.time()
        pending_items = True
        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Clean up ZMQ resources
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
            if self._replay is not None:
                self._replay.close(linger=0)
        finally:
            pass  # Do not terminate context; other sockets may use it

    def _socket_setup(self) -> None:
        """Initialize sockets

        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith("ipc://")
                or self._endpoint.startswith("inproc://")
            ):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request -> many replies (streamed events)
        # 3) works in our non-blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        self._pack = msgspec.msgpack.Encoder()
        assert self._pub is not None  # narrows type for mypy

        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
                if event is None:
                    break  # Sentinel received, exit thread
            except queue.Empty:
                continue

            try:
                seq = next(self._seq_gen)
                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))
                self._buffer.append((seq, payload))
                self._event_queue.task_done()
            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (-1, b"")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
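
A minimal replay-client sketch (not part of this commit; the replay endpoint address is hypothetical). A DEALER socket is used so that a single request can be answered with a stream of replies; the explicit empty delimiter frame mimics the REQ envelope that the ROUTER side expects:

    import zmq
    from msgspec.msgpack import Decoder

    from sglang.srt.disaggregation.kv_events import KVEventBatch

    decoder = Decoder(type=KVEventBatch)
    sock = zmq.Context.instance().socket(zmq.DEALER)
    sock.connect("tcp://localhost:5558")  # hypothetical replay_endpoint

    # Ask for everything from sequence number 0 onward (8-byte big-endian).
    sock.send_multipart([b"", (0).to_bytes(8, "big")])
    while True:
        _empty, seq_bytes, payload = sock.recv_multipart()
        seq = int.from_bytes(seq_bytes, "big", signed=True)
        if seq == -1:  # END_SEQ marker; payload is b""
            break
        print(seq, decoder.decode(payload))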
class KVEventsConfig(BaseModel):
    """Configuration for KV event publishing."""

    publisher: str = "null"
    """The publisher to use for publishing kv events. Can be "null", "zmq"."""

    endpoint: str = "tcp://*:5557"
    """The zmq endpoint to use for publishing kv events."""

    replay_endpoint: Optional[str] = None
    """The zmq endpoint to use for replaying kv events."""

    buffer_steps: int = 10_000
    """The number of steps to cache for replay endpoint. Will only save
    events from the last N steps for the replay endpoint."""

    hwm: int = 100_000
    """The zmq high water mark for the event publisher. After queueing N events,
    events will start dropping if the consumer is not keeping up."""

    max_queue_size: int = 100_000
    """The maximum number of events to queue while waiting for publishing."""

    topic: str = ""
    """The topic to use for the event publisher. Consumers can subscribe to
    this topic to receive events."""

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
        """Parse the CLI value for the event publisher config."""
        return KVEventsConfig.model_validate_json(cli_value)
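
For example, the JSON string below (a sketch; the field values mirror the defaults, and only "publisher" must differ from "null" for events to flow) parses into a config whose non-"publisher" fields are forwarded to the selected publisher's constructor:

    from sglang.srt.disaggregation.kv_events import KVEventsConfig

    cfg = KVEventsConfig.from_cli(
        '{"publisher": "zmq", "endpoint": "tcp://*:5557", "topic": "kv-events"}'
    )
    print(cfg.publisher, cfg.endpoint, cfg.topic)  # zmq tcp://*:5557 kv-events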
class EventPublisherFactory:
    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[str]) -> EventPublisher:
        """Create publisher from a config mapping."""
        if not config:
            return NullEventPublisher()
        config = KVEventsConfig.from_cli(config)
        config_dict = config.model_dump()
        kind = config_dict.pop("publisher", "null")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(**config_dict)
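
The registry makes the publisher pluggable. A sketch of wiring in a custom sink (the LogEventPublisher name and behavior are invented for illustration; it absorbs unused config fields via **kwargs, since create() passes every non-"publisher" field through as keyword arguments):

    class LogEventPublisher(EventPublisher):
        """Hypothetical publisher that just logs batch sizes."""

        def __init__(self, **kwargs) -> None:
            self._log = logging.getLogger("kv-events")

        def publish(self, events: EventBatch) -> None:
            self._log.info("batch ts=%s n=%d", events.ts, len(events.events))

        def shutdown(self) -> None:
            return

    EventPublisherFactory.register_publisher("log", LogEventPublisher)
    # Now selectable from the CLI via '{"publisher": "log"}'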
python/sglang/srt/managers/scheduler.py

...
@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
...
@@ -197,6 +198,7 @@ class Scheduler(
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.enable_kv_cache_events = server_args.kv_events_config is not None
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
...
@@ -204,7 +206,6 @@ class Scheduler(
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.page_size = server_args.page_size

         # Distributed rank info
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
...
@@ -422,6 +423,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics()
+        self.init_kv_events(server_args.kv_events_config)

         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
...
@@ -515,6 +517,7 @@ class Scheduler(
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             page_size=self.page_size,
             disable=server_args.disable_radix_cache,
+            enable_kv_cache_events=self.enable_kv_cache_events,
         )
         self.decode_mem_cache_buf_multiplier = (
...
@@ -547,6 +550,10 @@ class Scheduler(
             },
         )

+    def init_kv_events(self, kv_events_config: Optional[str]):
+        if self.enable_kv_cache_events:
+            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
+
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
...
@@ -1154,6 +1161,7 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
         self.metrics_collector.log_stats(self.stats)
+        self._publish_kv_events()

     def log_decode_stats(
         self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
...
@@ -1213,6 +1221,7 @@ class Scheduler(
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.stats.spec_accept_length = spec_accept_length
         self.metrics_collector.log_stats(self.stats)
+        self._publish_kv_events()

     def check_memory(self):
         available_size = (
...
@@ -1260,6 +1269,7 @@ class Scheduler(
         self.stats.num_queue_reqs = len(self.waiting_queue)
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.metrics_collector.log_stats(self.stats)
+        self._publish_kv_events()

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
...
@@ -2194,6 +2204,13 @@ class Scheduler(
             prefix += f" PP{self.pp_rank}"
         return prefix

+    def _publish_kv_events(self):
+        if self.enable_kv_cache_events:
+            events = self.tree_cache.take_events()
+            if events:
+                batch = KVEventBatch(ts=time.time(), events=events)
+                self.kv_event_publisher.publish(batch)
+

 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
...
python/sglang/srt/mem_cache/base_prefix_cache.py

...
@@ -48,3 +48,6 @@ class BasePrefixCache(ABC):
     def pretty_print(self):
         raise NotImplementedError()
+
+    def take_events(self):
+        return []
python/sglang/srt/mem_cache/radix_cache.py

...
@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 import torch

+from sglang.srt.disaggregation.kv_events import (
+    AllBlocksCleared,
+    BlockRemoved,
+    BlockStored,
+    KVCacheEvent,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
...
@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
+        enable_kv_cache_events: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
         self.disable = disable
+        self.enable_kv_cache_events = enable_kv_cache_events
+        self.kv_event_queue = []

         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
...
@@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
         self.protected_size_ = 0
+        self._record_all_cleared_event()

     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
...
@@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache):
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)

+            self._record_remove_event(x)
+
     def inc_lock_ref(self, node: TreeNode):
         if self.disable:
             return 0
...
@@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache):
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
+        self._record_remove_event(child)
         new_node = TreeNode()
         new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
...
@@ -358,6 +371,10 @@ class RadixCache(BasePrefixCache):
         child.key = child.key[split_len:]
         child.value = child.value[split_len:]
         new_node.parent.children[self.get_child_key_fn(key)] = new_node
+
+        self._record_store_event(new_node)
+        self._record_store_event(child)
+
         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
...
@@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache):
             new_node.value = value
             node.children[child_key] = new_node
             self.evictable_size_ += len(value)
+            self._record_store_event(new_node)
         return total_prefix_length

     def _print_helper(self, node: TreeNode, indent: int):
...
@@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache):
         return ret_list

+    def _record_store_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            parent_block_hash = hash(tuple(node.parent.key))
+            self.kv_event_queue.append(
+                BlockStored(
+                    block_hashes=[block_hash],
+                    parent_block_hash=parent_block_hash,
+                    token_ids=node.key,
+                    block_size=len(node.key),
+                    lora_id=None,
+                )
+            )
+
+    def _record_remove_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
+
+    def _record_all_cleared_event(self):
+        if self.enable_kv_cache_events:
+            self.kv_event_queue.append(AllBlocksCleared())
+
+    def take_events(self):
+        """Atomically takes all events and clears the queue.
+
+        Returns:
+            A list of KV cache events.
+        """
+        if not self.enable_kv_cache_events:
+            return []
+        events = self.kv_event_queue
+        self.kv_event_queue = []
+        return events
+

 if __name__ == "__main__":
     tree = RadixCache(None, None, page_size=1, disable=False)
...
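
Note how the recorded events relate: each node's block hash covers only that node's own key segment (hash(tuple(node.key))), and the parent linkage is carried explicitly via parent_block_hash rather than by chaining hashes cumulatively. A sketch, with token IDs borrowed from the test file below:

    # Parent and child key segments as stored on adjacent radix-tree nodes.
    parent_key = [128000, 791, 6864]    # "<begin> The capital"
    child_key = [3363, 315, 9822, 374]  # "city of France is"

    # What _record_store_event would emit for the child node.
    child_event = BlockStored(
        block_hashes=[hash(tuple(child_key))],
        parent_block_hash=hash(tuple(parent_key)),
        token_ids=child_key,
        block_size=len(child_key),
        lora_id=None,
    )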
python/sglang/srt/server_args.py

...
@@ -103,6 +103,7 @@ class ServerArgs:
     collect_tokens_histogram: bool = False
     decode_log_interval: int = 40
     enable_request_time_stats_logging: bool = False
+    kv_events_config: Optional[str] = None

     # API related
     api_key: Optional[str] = None
...
@@ -814,6 +815,12 @@ class ServerArgs:
             default=ServerArgs.collect_tokens_histogram,
             help="Collect prompt/generation tokens histogram.",
         )
+        parser.add_argument(
+            "--kv-events-config",
+            type=str,
+            default=None,
+            help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
+        )
         parser.add_argument(
             "--decode-log-interval",
             type=int,
...
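
A standalone subscriber sketch (mirroring what the new test below does; the endpoint matches the ZmqEventPublisher default and the "kv-events" topic matches the test's config). Each published message is a (topic, 8-byte big-endian sequence number, msgpack payload) triple:

    import zmq
    from msgspec.msgpack import Decoder

    from sglang.srt.disaggregation.kv_events import KVEventBatch

    decoder = Decoder(type=KVEventBatch)
    sub = zmq.Context.instance().socket(zmq.SUB)
    sub.connect("tcp://localhost:5557")
    sub.setsockopt_string(zmq.SUBSCRIBE, "kv-events")

    while True:
        _topic, seq_bytes, payload = sub.recv_multipart()
        batch = decoder.decode(payload)
        print(int.from_bytes(seq_bytes, "big"), batch.events)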
test/srt/test_kv_events.py (new file, 0 → 100644)

import time
import unittest

import msgspec
import requests
import zmq
from msgspec.msgpack import Decoder

from sglang.srt.disaggregation.kv_events import (
    AllBlocksCleared,
    BlockRemoved,
    BlockStored,
    EventBatch,
    KVCacheEvent,
    KVEventBatch,
)
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestKvEvents(CustomTestCase):
    def test_kv_events_enabled(self):
        """Test that kv events are sent and received by a subscriber when enabled"""
        # Launch kv events subscriber
        decoder = Decoder(type=KVEventBatch)
        context = zmq.Context()
        sub = context.socket(zmq.SUB)
        sub.connect("tcp://localhost:5557")
        topic = "kv-events"
        sub.setsockopt_string(zmq.SUBSCRIBE, topic)

        # Launch sglang server
        process = popen_launch_server(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                "--max-total-tokens",
                32,
                "--cuda-graph-max-bs",
                2,
            ],
        )
        try:
            # Make some requests to generate some metrics
            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
            self.assertEqual(response.status_code, 200)
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of Spain is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )

            # Expected events. These may be dependent on the model used
            # (meta-llama/Llama-3.2-1B-Instruct)
            expected_events = [
                # <begin> The capital city of France is
                BlockStored(
                    block_hashes=[-6650323075460941099],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864, 3363, 315, 9822, 374],
                    block_size=7,
                    lora_id=None,
                ),
                # Paris. The Eiffel Tower
                BlockStored(
                    block_hashes=[-7584018293207282755],
                    parent_block_hash=-6650323075460941099,
                    token_ids=[12366, 13, 578, 469, 3168, 301, 22703],
                    block_size=7,
                    lora_id=None,
                ),
                BlockStored(
                    block_hashes=[-8753497827991233192],
                    parent_block_hash=5740354900026072187,
                    token_ids=[0],
                    block_size=1,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-6650323075460941099]),
                # <begin> The capital
                BlockStored(
                    block_hashes=[-2697055055087824455],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864],
                    block_size=3,
                    lora_id=None,
                ),
                # city of France is
                BlockStored(
                    block_hashes=[-7505627135785778022],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[3363, 315, 9822, 374],
                    block_size=4,
                    lora_id=None,
                ),
                # of France is
                BlockStored(
                    block_hashes=[-3861108700662737012],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315, 9822, 374],
                    block_size=3,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-7584018293207282755]),
                BlockRemoved(block_hashes=[-8753497827991233192]),
                BlockRemoved(block_hashes=[-7505627135785778022]),
                # Paris. The Eiffel Tower is located in Paris. The Eiffel Tower is a famous landmark in Paris
                BlockStored(
                    block_hashes=[-3064341286825792715],
                    parent_block_hash=-3861108700662737012,
                    token_ids=[
                        12366, 13, 578, 469, 3168, 301, 22703, 374, 7559, 304,
                        12366, 13, 578, 469, 3168, 301, 22703, 374, 264, 11495,
                        38350, 304, 12366,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3861108700662737012]),
                # of
                BlockStored(
                    block_hashes=[6115672085296369592],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315],
                    block_size=1,
                    lora_id=None,
                ),
                # France is
                BlockStored(
                    block_hashes=[4208810872343132234],
                    parent_block_hash=6115672085296369592,
                    token_ids=[9822, 374],
                    block_size=2,
                    lora_id=None,
                ),
                # Spain is
                BlockStored(
                    block_hashes=[1675819893649989955],
                    parent_block_hash=6115672085296369592,
                    token_ids=[18157, 374],
                    block_size=2,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3064341286825792715]),
                # Madrid. The capital of France is Paris. The capital of Italy is Rome. The capital of Spain is Madrid.
                BlockStored(
                    block_hashes=[-8505834929190027295],
                    parent_block_hash=1675819893649989955,
                    token_ids=[
                        25048, 13, 578, 6864, 315, 9822, 374, 12366, 13, 578,
                        6864, 315, 15704, 374, 22463, 13, 578, 6864, 315, 18157,
                        374, 25048, 13,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
            ]

            # Get events
            events = []
            start = time.time()
            max_wait_s = 5
            while (
                len(events) < len(expected_events)
                and (time.time() - start) < max_wait_s
            ):
                _, seq_bytes, payload = sub.recv_multipart()
                event_batch = decoder.decode(payload)
                for event in event_batch.events:
                    print(f" - {event}")
                    events.append(event)

            for expected in expected_events:
                self.assertIn(expected, events)
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()