Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
3f84cdad
Commit
3f84cdad
authored
Mar 11, 2025
by
Alec
Committed by
GitHub
Mar 11, 2025
Browse files
feat: add new metrics and simple router cost fn (#88)
parent
2153ee81
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
347 additions
and
332 deletions
+347
-332
components/metrics/src/bin/mock_worker.rs
components/metrics/src/bin/mock_worker.rs
+6
-0
container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch
container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch
+51
-51
examples/python_rs/llm/vllm/README.md
examples/python_rs/llm/vllm/README.md
+1
-1
examples/python_rs/llm/vllm/kv_router/router.py
examples/python_rs/llm/vllm/kv_router/router.py
+121
-120
examples/python_rs/llm/vllm/kv_router/worker.py
examples/python_rs/llm/vllm/kv_router/worker.py
+7
-4
examples/python_rs/llm/vllm/scripts/kv-router-run.sh
examples/python_rs/llm/vllm/scripts/kv-router-run.sh
+51
-34
examples/rust/llmctl/src/main.rs
examples/rust/llmctl/src/main.rs
+4
-0
lib/bindings/python/rust/llm/kv.rs
lib/bindings/python/rust/llm/kv.rs
+20
-0
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+12
-2
lib/bindings/python/src/dynamo/llm/__init__.py
lib/bindings/python/src/dynamo/llm/__init__.py
+2
-0
lib/bindings/python/tests/test_kv_bindings.py
lib/bindings/python/tests/test_kv_bindings.py
+6
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+3
-72
lib/llm/src/kv_router/indexer.rs
lib/llm/src/kv_router/indexer.rs
+8
-0
lib/llm/src/kv_router/metrics_aggregator.rs
lib/llm/src/kv_router/metrics_aggregator.rs
+49
-48
lib/llm/src/kv_router/protocols.rs
lib/llm/src/kv_router/protocols.rs
+6
-0
No files found.
components/metrics/src/bin/mock_worker.rs
View file @
3f84cdad
...
@@ -112,11 +112,17 @@ fn mock_stats_handler(_stats: Stats) -> serde_json::Value {
...
@@ -112,11 +112,17 @@ fn mock_stats_handler(_stats: Stats) -> serde_json::Value {
let
request_active_slots
=
rand
::
thread_rng
()
.gen_range
(
0
..=
request_total_slots
);
let
request_active_slots
=
rand
::
thread_rng
()
.gen_range
(
0
..=
request_total_slots
);
let
kv_total_blocks
=
100
;
let
kv_total_blocks
=
100
;
let
kv_active_blocks
=
rand
::
thread_rng
()
.gen_range
(
0
..=
kv_total_blocks
);
let
kv_active_blocks
=
rand
::
thread_rng
()
.gen_range
(
0
..=
kv_total_blocks
);
let
num_requests_waiting
=
rand
::
thread_rng
()
.gen_range
(
0
..=
100
);
let
gpu_cache_usage_perc
=
rand
::
thread_rng
()
.gen_range
(
0.0
..=
1.0
);
let
gpu_prefix_cache_hit_rate
=
rand
::
thread_rng
()
.gen_range
(
0.0
..=
1.0
);
let
stats
=
ForwardPassMetrics
{
let
stats
=
ForwardPassMetrics
{
request_active_slots
,
request_active_slots
,
request_total_slots
,
request_total_slots
,
kv_active_blocks
,
kv_active_blocks
,
kv_total_blocks
,
kv_total_blocks
,
num_requests_waiting
,
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
,
};
};
println!
(
"stats out: {:?}"
,
stats
);
println!
(
"stats out: {:?}"
,
stats
);
serde_json
::
to_value
(
stats
)
.unwrap
()
serde_json
::
to_value
(
stats
)
.unwrap
()
...
...
container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch
View file @
3f84cdad
...
@@ -2990,7 +2990,7 @@ index 3cf1850e..ae006579 100644
...
@@ -2990,7 +2990,7 @@ index 3cf1850e..ae006579 100644
+ gpu_cache_usage_perc: float
+ gpu_cache_usage_perc: float
+ gpu_prefix_cache_hit_rate: float
+ gpu_prefix_cache_hit_rate: float
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..
fe71964
2 100644
index 85b5f31e..
0503029
2 100644
--- a/vllm/engine/multiprocessing/client.py
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
@@ -8,6 +8,7 @@
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...
@@ -3084,7 +3084,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3084,7 +3084,7 @@ index 85b5f31e..fe719642 100644
@staticmethod
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported
# Pipeline parallel not yet supported
@@ -180,6 +210,6
3
@@
class MQLLMEngineClient(EngineClient):
@@ -180,6 +210,6
1
@@
class MQLLMEngineClient(EngineClient):
except Exception as e:
except Exception as e:
self._set_errored(e)
self._set_errored(e)
...
@@ -3118,8 +3118,9 @@ index 85b5f31e..fe719642 100644
...
@@ -3118,8 +3118,9 @@ index 85b5f31e..fe719642 100644
+ # Metrics received- check the message
+ # Metrics received- check the message
+ message: Frame = await self.metrics_socket.recv(copy=False)
+ message: Frame = await self.metrics_socket.recv(copy=False)
+ metrics = pickle.loads(message.buffer)
+ metrics = pickle.loads(message.buffer)
+ if self.metrics_publisher is not None:
+ if self.metrics_publisher is not None and isinstance(
+ if isinstance(metrics, KvMetrics):
+ metrics, KvMetrics
+ ):
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ self.metrics_publisher.publish(metrics.request_active_slots,
+ metrics.request_total_slots,
+ metrics.request_total_slots,
+ metrics.kv_active_blocks,
+ metrics.kv_active_blocks,
...
@@ -3127,13 +3128,10 @@ index 85b5f31e..fe719642 100644
...
@@ -3127,13 +3128,10 @@ index 85b5f31e..fe719642 100644
+ metrics.num_requests_waiting,
+ metrics.num_requests_waiting,
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_cache_usage_perc,
+ metrics.gpu_prefix_cache_hit_rate)
+ metrics.gpu_prefix_cache_hit_rate)
+ if isinstance(metrics, Stats):
+ # TODO
+ # Send the whole stats to user
+ pass
+
+ logger.debug("Metrics successful.")
+ logger.debug("Metrics successful.")
+
+
+ # TODO: Investigate sending whole stats object
+
+ except asyncio.CancelledError:
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
+
+
...
@@ -3148,7 +3146,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3148,7 +3146,7 @@ index 85b5f31e..fe719642 100644
async def run_output_handler_loop(self):
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -278,12 +36
5
,26 @@
class MQLLMEngineClient(EngineClient):
@@ -278,12 +36
3
,26 @@
class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
response = await self._wait_for_server_rpc(socket)
...
@@ -3175,7 +3173,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3175,7 +3173,7 @@ index 85b5f31e..fe719642 100644
def close(self):
def close(self):
"""Destroy the ZeroMQ Context."""
"""Destroy the ZeroMQ Context."""
@@ -293,6 +39
4
,8 @@
class MQLLMEngineClient(EngineClient):
@@ -293,6 +39
2
,8 @@
class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
# Cancel background tasks.
if self.health_loop is not None:
if self.health_loop is not None:
self.health_loop.cancel()
self.health_loop.cancel()
...
@@ -3184,7 +3182,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3184,7 +3182,7 @@ index 85b5f31e..fe719642 100644
if self.output_loop is not None:
if self.output_loop is not None:
self.output_loop.cancel()
self.output_loop.cancel()
@@ -415,6 +51
8
,9 @@
class MQLLMEngineClient(EngineClient):
@@ -415,6 +51
6
,9 @@
class MQLLMEngineClient(EngineClient):
"""
"""
if self._errored_with is not None:
if self._errored_with is not None:
raise self._errored_with
raise self._errored_with
...
@@ -3194,7 +3192,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3194,7 +3192,7 @@ index 85b5f31e..fe719642 100644
@property
@property
def is_running(self) -> bool:
def is_running(self) -> bool:
@@ -473,6 +57
9
,7 @@
class MQLLMEngineClient(EngineClient):
@@ -473,6 +57
7
,7 @@
class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
priority: int = 0,
...
@@ -3202,7 +3200,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3202,7 +3200,7 @@ index 85b5f31e..fe719642 100644
*,
*,
inputs: Optional[PromptType] = None # DEPRECATED
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
) -> AsyncGenerator[RequestOutput, None]:
@@ -502,7 +60
9
,8 @@
class MQLLMEngineClient(EngineClient):
@@ -502,7 +60
7
,8 @@
class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
lora_request, trace_headers,
...
@@ -3212,7 +3210,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3212,7 +3210,7 @@ index 85b5f31e..fe719642 100644
@overload
@overload
def encode(
def encode(
@@ -586,6 +69
4
,7 @@
class MQLLMEngineClient(EngineClient):
@@ -586,6 +69
2
,7 @@
class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
priority: int = 0,
...
@@ -3220,7 +3218,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3220,7 +3218,7 @@ index 85b5f31e..fe719642 100644
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -630,6 +73
9
,12 @@
class MQLLMEngineClient(EngineClient):
@@ -630,6 +73
7
,12 @@
class MQLLMEngineClient(EngineClient):
else:
else:
lp_bytes = None
lp_bytes = None
...
@@ -3233,7 +3231,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3233,7 +3231,7 @@ index 85b5f31e..fe719642 100644
request_bytes = pickle.dumps(
request_bytes = pickle.dumps(
RPCProcessRequest(
RPCProcessRequest(
prompt=prompt,
prompt=prompt,
@@ -639,11 +75
4
,11 @@
class MQLLMEngineClient(EngineClient):
@@ -639,11 +75
2
,11 @@
class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
priority=priority,
...
@@ -3247,7 +3245,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3247,7 +3245,7 @@ index 85b5f31e..fe719642 100644
await self.input_socket.send_multipart(parts, copy=False)
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
# 4) Stream the RequestOutputs from the output queue. Note
@@ -705,3 +8
20
,6 @@
class MQLLMEngineClient(EngineClient):
@@ -705,3 +8
18
,6 @@
class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
if isinstance(request_output, BaseException):
raise request_output
raise request_output
...
@@ -3255,7 +3253,7 @@ index 85b5f31e..fe719642 100644
...
@@ -3255,7 +3253,7 @@ index 85b5f31e..fe719642 100644
+ def set_metrics_publisher(self, metrics_publisher):
+ def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher
+ self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..
3204cfb8
100644
index a0dd7958..
c82bc15b
100644
--- a/vllm/engine/multiprocessing/engine.py
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -3,35 +3,115 @@
@@ -3,35 +3,115 @@
...
@@ -3355,23 +3353,23 @@ index a0dd7958..3204cfb8 100644
...
@@ -3355,23 +3353,23 @@ index a0dd7958..3204cfb8 100644
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+
+# TODO: Send entire stats object to the client
+# TODO: Send entire stats object to the client
+class StatLogger(StatLoggerBase):
+
#
class StatLogger(StatLoggerBase):
+ def __init__(
+
#
def __init__(
+ self,
+
#
self,
+ metrics_socket
+
#
metrics_socket
+ ):
+
#
):
+ self.metrics_socket = metrics_socket
+
#
self.metrics_socket = metrics_socket
+
+
+ def log(self, stats: Stats) -> None:
+
#
def log(self, stats: Stats) -> None:
+ self._send_metrics(stats)
+
#
self._send_metrics(stats)
+
+
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+
#
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+ pass
+
#
pass
+
+
+ def _send_metrics(self, stats: Stats):
+
#
def _send_metrics(self, stats: Stats):
+ if not self.metrics_socket.closed:
+
#
if not self.metrics_socket.closed:
+ metrics_bytes = pickle.dumps(stats)
+
#
metrics_bytes = pickle.dumps(stats)
+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
#
self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
+
+
+
+
+
+
...
@@ -3379,7 +3377,7 @@ index a0dd7958..3204cfb8 100644
...
@@ -3379,7 +3377,7 @@ index a0dd7958..3204cfb8 100644
class MQLLMEngine:
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
"""A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +174,3
5
@@
class MQLLMEngine:
@@ -94,12 +174,3
7
@@
class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
...
@@ -3406,16 +3404,18 @@ index a0dd7958..3204cfb8 100644
...
@@ -3406,16 +3404,18 @@ index a0dd7958..3204cfb8 100644
+ self.engine.cache_config.num_gpu_blocks,
+ self.engine.cache_config.num_gpu_blocks,
+ self.metrics_socket
+ self.metrics_socket
+ )
+ )
+ self.general_stat_logger = StatLogger(
+ self.metrics_socket
+ )
+ self.engine.add_logger("kv_metrics", self.kv_stat_logger)
+ self.engine.add_logger("kv_metrics", self.kv_stat_logger)
+ self.engine.add_logger("general_metrics", self.general_stat_logger)
+
+ # TODO investigate sending whole stats object
+ # self.general_stat_logger = StatLogger(
+ # self.metrics_socket
+ # )
+ # self.engine.add_logger("general_metrics", self.general_stat_logger)
+
+
@property
@property
def dead_error(self) -> BaseException:
def dead_error(self) -> BaseException:
if self._errored_with is not None:
if self._errored_with is not None:
@@ -171,8 +27
4
,17 @@
class MQLLMEngine:
@@ -171,8 +27
6
,17 @@
class MQLLMEngine:
# Handle the query from the Client.
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
tracing_enabled = self.engine.is_tracing_enabled()
...
@@ -3435,7 +3435,7 @@ index a0dd7958..3204cfb8 100644
...
@@ -3435,7 +3435,7 @@ index a0dd7958..3204cfb8 100644
except Exception as e:
except Exception as e:
response = e
response = e
@@ -185,6 +29
7
,7 @@
class MQLLMEngine:
@@ -185,6 +29
9
,7 @@
class MQLLMEngine:
while True:
while True:
if not self.engine.has_unfinished_requests():
if not self.engine.has_unfinished_requests():
...
@@ -3443,7 +3443,7 @@ index a0dd7958..3204cfb8 100644
...
@@ -3443,7 +3443,7 @@ index a0dd7958..3204cfb8 100644
# Poll until there is work to do.
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
# When there's no work, check on engine health and send
@@ -220,6 +33
3
,13 @@
class MQLLMEngine:
@@ -220,6 +33
5
,13 @@
class MQLLMEngine:
def handle_new_input(self):
def handle_new_input(self):
"""Handle new input from the socket"""
"""Handle new input from the socket"""
try:
try:
...
@@ -3457,7 +3457,7 @@ index a0dd7958..3204cfb8 100644
...
@@ -3457,7 +3457,7 @@ index a0dd7958..3204cfb8 100644
while self.input_socket.poll(timeout=0) != 0:
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
request = pickle.loads(frames[0].buffer)
@@ -262,6 +38
2
,11 @@
class MQLLMEngine:
@@ -262,6 +38
4
,11 @@
class MQLLMEngine:
self._send_outputs(rpc_err)
self._send_outputs(rpc_err)
try:
try:
...
@@ -3469,7 +3469,7 @@ index a0dd7958..3204cfb8 100644
...
@@ -3469,7 +3469,7 @@ index a0dd7958..3204cfb8 100644
self.engine.add_request(
self.engine.add_request(
request_id=request_id,
request_id=request_id,
prompt=request.prompt,
prompt=request.prompt,
@@ -269,7 +39
4
,9 @@
class MQLLMEngine:
@@ -269,7 +39
6
,9 @@
class MQLLMEngine:
lora_request=request.lora_request,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
prompt_adapter_request=request.prompt_adapter_request,
...
...
examples/python_rs/llm/vllm/README.md
View file @
3f84cdad
...
@@ -237,7 +237,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
...
@@ -237,7 +237,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example:
Example:
```
bash
```
bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8
prefix
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
bash /workspace/examples/python_rs/llm/vllm/scripts/kv-router-run.sh 8
test
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# List tmux sessions
# List tmux sessions
tmux
ls
tmux
ls
...
...
examples/python_rs/llm/vllm/kv_router/router.py
View file @
3f84cdad
...
@@ -15,98 +15,138 @@
...
@@ -15,98 +15,138 @@
import
asyncio
import
asyncio
import
random
from
argparse
import
Namespace
from
argparse
import
Namespace
from
enum
import
Enum
from
typing
import
AsyncIterator
from
typing
import
AsyncIterator
import
uvloop
import
uvloop
from
common.protocol
import
Tokens
from
common.protocol
import
Tokens
from
vllm.logger
import
logger
as
vllm_logger
from
vllm.logger
import
logger
as
vllm_logger
from
dynamo.llm
import
KvIndexer
,
KvMetricsAggregator
,
KvRouter
from
dynamo.llm
import
AggregatedMetrics
,
KvIndexer
,
KvMetricsAggregator
,
OverlapScores
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_endpoint
,
dynamo_worker
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_endpoint
,
dynamo_worker
WorkerId
=
str
WorkerId
=
str
class
RoutingStrategy
(
Enum
):
class
CustomRouter
:
PREFIX
=
"prefix"
ROUND_ROBIN
=
"round_robin"
RANDOM
=
"random"
class
Router
:
"""
"""
Request handler for the generate endpoint
Request handler for the generate endpoint
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
router
:
KvRouter
,
workers_client
,
routing_strategy
:
RoutingStrategy
=
RoutingStrategy
.
PREFIX
,
indexer
:
KvIndexer
,
metrics_aggregator
:
KvMetricsAggregator
,
):
):
vllm_logger
.
info
(
vllm_logger
.
info
(
"Initializing Custom Router"
)
f
"Initializing KV Router with strategy:
{
routing_strategy
.
value
}
"
self
.
indexer
=
indexer
self
.
metrics_aggregator
=
metrics_aggregator
self
.
workers_client
=
workers_client
def
_cost_function
(
self
,
scores
:
OverlapScores
|
None
,
metrics
:
AggregatedMetrics
|
None
,
token_length
:
int
,
):
worker_scores
=
{}
if
scores
:
for
worker_id
,
score
in
scores
.
scores
.
items
():
# score is number of matching blocks we multiply by block_size to get tokens
# and compare to token_length. The larger the cache hit the better
worker_scores
[
worker_id
]
=
(
score
*
self
.
indexer
.
block_size
()
/
token_length
)
)
self
.
router
=
router
self
.
routing_strategy
=
routing_strategy
@
dynamo_endpoint
(
Tokens
,
WorkerId
)
worker_metrics
=
{}
async
def
generate
(
self
,
request
)
->
AsyncIterator
[
WorkerId
]:
# pull metrics for each worker
lora_id
=
0
max_waiting
=
0.0
worker_id
=
None
if
metrics
:
if
self
.
routing_strategy
==
RoutingStrategy
.
PREFIX
:
for
endpoint
in
metrics
.
endpoints
:
try
:
worker_id
=
endpoint
.
worker_id
worker_id
=
await
self
.
router
.
schedule
(
request
.
tokens
,
lora_id
)
worker_metrics
[
worker_id
]
=
{
# [NOTE][TODO] Now that the scheduler may return more error messages,
"gpu_cache_usage_perc"
:
endpoint
.
gpu_cache_usage_perc
# now we are catching all exceptions and logging them. Should have
if
hasattr
(
endpoint
,
"gpu_cache_usage_perc"
)
# catch specific router exceptions once we have dedicated types.
else
0.0
,
except
Exception
as
e
:
"num_requests_waiting"
:
endpoint
.
num_requests_waiting
vllm_logger
.
info
(
f
"
{
e
}
"
)
if
hasattr
(
endpoint
,
"num_requests_waiting"
)
worker_id
=
""
else
0.0
,
vllm_logger
.
exception
(
f
"Error during worker selection:
{
e
}
"
)
"gpu_prefix_cache_hit_rate"
:
endpoint
.
gpu_prefix_cache_hit_rate
if
hasattr
(
endpoint
,
"gpu_prefix_cache_hit_rate"
)
else
0.0
,
}
max_waiting
=
max
(
max_waiting
,
worker_metrics
[
worker_id
][
"num_requests_waiting"
]
)
vllm_logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
"
)
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids
=
self
.
workers_client
.
endpoint_ids
()
worker_logits
=
{}
for
worker_id
in
worker_ids
:
# Use default values if worker not in scores or metrics
score
=
worker_scores
.
get
(
worker_id
,
0.0
)
metrics_dict
=
worker_metrics
.
get
(
worker_id
,
{
"gpu_cache_usage_perc"
:
0.0
,
"num_requests_waiting"
:
0.0
,
"gpu_prefix_cache_hit_rate"
:
0.0
,
},
)
yield
str
(
worker_id
)
normalized_waiting
=
(
metrics_dict
[
"num_requests_waiting"
]
/
max_waiting
if
max_waiting
>
0
else
0.0
)
else
:
# Have 1 metric that weights towards cache hit
# TODO: Do we implement round_robin and random here?
# 2 metrics that penalize overloaded worker and queuing
# or just skip this router and directly enable in preprocess?
worker_logits
[
worker_id
]
=
(
raise
NotImplementedError
(
2
*
score
-
metrics_dict
[
"gpu_cache_usage_perc"
]
-
normalized_waiting
f
"Routing strategy
{
self
.
routing_strategy
}
not implemented"
)
vllm_logger
.
info
(
f
"Formula for
{
worker_id
}
:
{
worker_logits
[
worker_id
]:.
3
f
}
= 2.0 *
{
score
:.
3
f
}
-
{
metrics_dict
[
'gpu_cache_usage_perc'
]:.
3
f
}
-
{
normalized_waiting
:.
3
f
}
"
)
)
if
not
worker_logits
or
all
(
logit
==
0
for
logit
in
worker_logits
.
values
()):
return
""
class
CustomRouter
:
# Select the worker with the highest logit
"""
if
worker_logits
:
Request handler for the generate endpoint
max_logit
=
max
(
worker_logits
.
values
())
"""
best_workers
=
[
wid
for
wid
,
logit
in
worker_logits
.
items
()
if
logit
==
max_logit
]
best_worker_id
=
random
.
choice
(
best_workers
)
else
:
best_worker_id
=
""
def
__init__
(
# Log the metrics for the selected worker
self
,
if
best_worker_id
:
indexer
:
KvIndexer
,
vllm_logger
.
info
(
metrics_aggregator
:
KvMetricsAggregator
,
f
"Selected worker:
{
best_worker_id
}
, logit:
{
worker_logits
[
best_worker_id
]:.
3
f
}
"
):
)
self
.
indexer
=
indexer
vllm_logger
.
info
(
self
.
metrics_aggregator
=
metrics_aggregator
f
"Score:
{
scores
.
scores
.
get
(
best_worker_id
,
0.0
)
if
scores
else
0.0
:.
3
f
}
"
)
def
_cost_function
(
self
,
scores
,
metrics
):
metrics_dict
=
worker_metrics
.
get
(
best_worker_id
,
{})
# naive cost function for demonstration purposes
vllm_logger
.
info
(
current_best
=
(
""
,
0
)
f
"GPU Cache Hit Rate:
{
metrics_dict
.
get
(
'gpu_prefix_cache_hit_rate'
,
0.0
):.
3
f
}
"
for
worker_id
,
score
in
scores
.
scores
.
items
():
if
score
>
current_best
[
1
]:
current_best
=
(
worker_id
,
score
)
for
endpoint
in
metrics
.
endpoints
:
if
endpoint
.
worker_id
==
current_best
[
0
]:
print
(
f
"Metrics of endpoint:
{
endpoint
.
worker_id
}
"
)
print
(
f
"request slot usage:
{
endpoint
.
request_active_slots
}
/
{
endpoint
.
request_total_slots
}
"
)
)
print
(
vllm_logger
.
info
(
f
"KV block usage:
{
endpoint
.
kv_active_blocks
}
/
{
endpoint
.
kv_total_blocks
}
"
f
"GPU Cache Usage:
{
metrics_dict
.
get
(
'gpu_cache_usage_perc'
,
0.0
):.
3
f
}
"
)
vllm_logger
.
info
(
f
"Requests Waiting:
{
metrics_dict
.
get
(
'num_requests_waiting'
,
0.0
)
/
max_waiting
if
max_waiting
>
0
else
0.0
:.
3
f
}
"
)
)
return
current_best
[
0
]
return
best_worker_id
@
dynamo_endpoint
(
Tokens
,
WorkerId
)
@
dynamo_endpoint
(
Tokens
,
WorkerId
)
async
def
generate
(
self
,
request
)
->
AsyncIterator
[
WorkerId
]:
async
def
generate
(
self
,
request
)
->
AsyncIterator
[
WorkerId
]:
...
@@ -116,18 +156,16 @@ class CustomRouter:
...
@@ -116,18 +156,16 @@ class CustomRouter:
scores
=
await
self
.
indexer
.
find_matches_for_request
(
scores
=
await
self
.
indexer
.
find_matches_for_request
(
request
.
tokens
,
lora_id
request
.
tokens
,
lora_id
)
)
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
worker_id
=
self
.
_cost_function
(
scores
,
metrics
)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except
Exception
as
e
:
except
Exception
as
e
:
vllm_logger
.
info
(
f
"
{
e
}
"
)
scores
=
{}
worker_id
=
""
vllm_logger
.
exception
(
f
"Error finding matches:
{
e
}
"
)
vllm_logger
.
exception
(
f
"Error during worker selection:
{
e
}
"
)
token_length
=
len
(
request
.
tokens
)
metrics
=
await
self
.
metrics_aggregator
.
get_metrics
()
worker_id
=
self
.
_cost_function
(
scores
,
metrics
,
token_length
)
vllm_logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
"
)
vllm_logger
.
info
(
f
"Scheduling to worker_id:
{
worker_id
}
"
)
vllm_logger
.
info
(
"########"
)
yield
str
(
worker_id
)
yield
str
(
worker_id
)
...
@@ -144,14 +182,6 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
...
@@ -144,14 +182,6 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
.
endpoint
(
"generate"
)
.
endpoint
(
"generate"
)
.
client
()
.
client
()
)
)
wait_task
=
workers_client
.
wait_for_endpoints
()
await
asyncio
.
sleep
(
1
)
while
not
wait_task
.
done
():
vllm_logger
.
info
(
"Waiting for workers to be ready..."
)
await
asyncio
.
sleep
(
5
)
wait_task
.
result
()
while
len
(
workers_client
.
endpoint_ids
())
<
args
.
min_workers
:
while
len
(
workers_client
.
endpoint_ids
())
<
args
.
min_workers
:
vllm_logger
.
info
(
vllm_logger
.
info
(
...
@@ -172,23 +202,11 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
...
@@ -172,23 +202,11 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
endpoint
=
router_component
.
endpoint
(
"generate"
)
endpoint
=
router_component
.
endpoint
(
"generate"
)
if
args
.
custom_router
:
# @REVIEWER - I'm not currently checking if block size matches that of the engine
# If they don't match things will silently fail
# The preferred solution would be for the KV Indexer to read from the MDC in etcd and not bother the user at all
# The second solution would be to do KvIndexer(kv_listener, MDC.block_size)
# as this ensures block size matches that of the engine
# In this case we need to do some sort of handshake or check in case a user just puts in a random block size
indexer
=
KvIndexer
(
kv_listener
,
args
.
block_size
)
indexer
=
KvIndexer
(
kv_listener
,
args
.
block_size
)
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
metrics_aggregator
=
KvMetricsAggregator
(
kv_listener
)
await
endpoint
.
serve_endpoint
(
await
endpoint
.
serve_endpoint
(
CustomRouter
(
indexer
,
metrics_aggregator
).
generate
CustomRouter
(
workers_client
,
indexer
,
metrics_aggregator
).
generate
)
)
else
:
# TODO Read block_size from MDC
router
=
KvRouter
(
runtime
,
kv_listener
,
args
.
block_size
)
await
endpoint
.
serve_endpoint
(
Router
(
router
,
args
.
routing_strategy
).
generate
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -197,35 +215,18 @@ if __name__ == "__main__":
...
@@ -197,35 +215,18 @@ if __name__ == "__main__":
import
argparse
import
argparse
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--routing-strategy"
,
type
=
RoutingStrategy
,
default
=
RoutingStrategy
.
PREFIX
,
choices
=
list
(
RoutingStrategy
),
help
=
"Routing strategy to use"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--min-workers"
,
"--min-workers"
,
type
=
int
,
type
=
int
,
default
=
1
,
default
=
1
,
help
=
"Minimum number of workers required before proceeding"
,
help
=
"Minimum number of workers required before proceeding"
,
)
)
parser
.
add_argument
(
# TODO: Read block size
"--model-name"
,
type
=
str
,
default
=
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
,
help
=
"Model that is being served"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--block-size"
,
"--block-size"
,
type
=
int
,
type
=
int
,
help
=
"KV block size"
,
default
=
64
,
)
help
=
"Block size for the KV Indexer"
,
parser
.
add_argument
(
"--custom-router"
,
type
=
bool
,
default
=
False
,
help
=
"Whether to use custom router or not"
,
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
examples/python_rs/llm/vllm/kv_router/worker.py
View file @
3f84cdad
...
@@ -99,10 +99,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
...
@@ -99,10 +99,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
# Initially send dummy metrics to kick start,
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
# vLLM will not update stat until forward pass is triggered
metrics_publisher
.
publish
(
metrics_publisher
.
publish
(
0
,
0
,
# request_active_slots
1024
,
1024
,
# request_total_slots
0
,
0
,
# kv_active_blocks
1024
,
1024
,
# kv_total_blocks
0
,
# num_requests_waiting
0.0
,
# gpu_cache_usage_perc
0.0
,
# gpu_prefix_cache_hit_rate
)
)
await
asyncio
.
gather
(
await
asyncio
.
gather
(
...
...
examples/python_rs/llm/vllm/scripts/kv-router-run.sh
View file @
3f84cdad
...
@@ -18,53 +18,78 @@
...
@@ -18,53 +18,78 @@
# - Must use a single GPU for workers as CUDA_VISIBLE_DEVICES is set to a fixed value
# - Must use a single GPU for workers as CUDA_VISIBLE_DEVICES is set to a fixed value
# - Must use a single node
# - Must use a single node
if
[
$#
-lt
2
]
;
then
if
[
$#
-lt
3
]
;
then
echo
"Usage:
$0
<number_of_workers> <
routing_strategy> [model_name] [
endpoint_name]"
echo
"Usage:
$0
<number_of_workers> <
log_dir_name> [model_name] [model_args] [chat_endpoint_name] [completions_
endpoint_name]"
echo
"Error: Must specify at least number of workers
and routing strategy
"
echo
"Error: Must specify at least number of workers
, log_dir_name
"
echo
"Optional: model_name (default: deepseek-ai/DeepSeek-R1-Distill-Llama-8B)"
echo
"Optional: model_name (default: deepseek-ai/DeepSeek-R1-Distill-Llama-8B)"
echo
"Optional: endpoint_name (default: dynamo.process.chat/completions)"
echo
"Optional: model_args (quoted string with model arguments)"
echo
"Optional: chat_endpoint_name (default: dynamo.process.chat/completions)"
echo
"Optional: completions_endpoint_name (default: dynamo.process.completions)"
exit
1
exit
1
fi
fi
# Uncomment if using Cache
# export HF_HUB_OFFLINE=1
# https://github.com/vllm-project/vllm/issues/10734#issuecomment-2507201353
# Fix for:torch.distributed.DistBackendError: File name too long
# export GLOO_SOCKET_IFNAME=lo
NUM_WORKERS
=
$1
NUM_WORKERS
=
$1
ROUTING_STRATEGY
=
$2
LOG_DIR_NAME
=
$2
MODEL_NAME
=
${
3
:-
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
}
MODEL_NAME
=
${
3
:-
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
}
ENDPOINT_NAME
=
${
4
:-
"dynamo.process.chat/completions"
}
CUSTOM_MODEL_ARGS
=
$4
VALID_STRATEGIES
=(
"prefix"
)
CHAT_ENDPOINT_NAME
=
${
5
:-
"dynamo.process.chat/completions"
}
COMPLETIONS_ENDPOINT_NAME
=
${
6
:-
"dynamo.process.completions"
}
SESSION_NAME
=
"v"
SESSION_NAME
=
"v"
WORKDIR
=
"/workspace/examples/python_rs/llm/vllm"
WORKDIR
=
"/workspace/examples/python_rs/llm/vllm"
INIT_CMD
=
"cd
$WORKDIR
"
INIT_CMD
=
"cd
$WORKDIR
"
if
[[
!
"
${
VALID_STRATEGIES
[@]
}
"
=
~
"
${
ROUTING_STRATEGY
}
"
]]
;
then
echo
"Error: Invalid routing strategy. Must be one of:
${
VALID_STRATEGIES
[*]
}
"
# Default model args
exit
1
DEFAULT_MODEL_ARGS
=
"--model
$MODEL_NAME
\
--tokenizer
$MODEL_NAME
\
--enable-prefix-caching
\
--block-size 64"
# Use custom model args if provided, otherwise use default
if
[
-n
"
$CUSTOM_MODEL_ARGS
"
]
;
then
MODEL_ARGS
=
"
$CUSTOM_MODEL_ARGS
"
echo
"Using custom model arguments"
else
MODEL_ARGS
=
"
$DEFAULT_MODEL_ARGS
"
echo
"Using default model arguments"
fi
fi
# Create logs directory if it doesn't exist
LOGS_DIR
=
"/logs/
$LOG_DIR_NAME
"
mkdir
-p
$LOGS_DIR
chmod
-R
775
$LOGS_DIR
########################################################
########################################################
# HTTP Server
# HTTP Server
########################################################
########################################################
HTTP_CMD
=
"DYN_LOG=DEBUG http"
HTTP_CMD
=
"DYN_LOG=DEBUG http
|& tee
$LOGS_DIR
/http.log
"
tmux new-session
-d
-s
"
$SESSION_NAME
-http"
tmux new-session
-d
-s
"
$SESSION_NAME
-http"
tmux send-keys
-t
"
$SESSION_NAME
-http"
"
$INIT_CMD
&&
$HTTP_CMD
"
C-m
tmux send-keys
-t
"
$SESSION_NAME
-http"
"
$INIT_CMD
&&
$HTTP_CMD
"
C-m
########################################################
########################################################
# LLMCTL
# LLMCTL
########################################################
########################################################
LLMCTL_CMD
=
"sleep 5 && llmctl http remove chat-model
$MODEL_NAME
&&
\
LLMCTL_CMD
=
"sleep 5 &&
\
llmctl http add chat-model
$MODEL_NAME
$ENDPOINT_NAME
&&
\
llmctl http remove chat
$MODEL_NAME
&&
\
llmctl http list chat-model"
llmctl http remove completions
$MODEL_NAME
&&
\
llmctl http add chat
$MODEL_NAME
$CHAT_ENDPOINT_NAME
&&
\
llmctl http add completions
$MODEL_NAME
$COMPLETIONS_ENDPOINT_NAME
&&
\
llmctl http list |& tee
$LOGS_DIR
/llmctl.log"
tmux new-session
-d
-s
"
$SESSION_NAME
-llmctl"
tmux new-session
-d
-s
"
$SESSION_NAME
-llmctl"
tmux send-keys
-t
"
$SESSION_NAME
-llmctl"
"
$INIT_CMD
&&
$LLMCTL_CMD
"
C-m
tmux send-keys
-t
"
$SESSION_NAME
-llmctl"
"
$INIT_CMD
&&
$LLMCTL_CMD
"
C-m
########################################################
########################################################
# Processor
# Processor
########################################################
########################################################
# For now processor gets same args as worker, need to have them communicate over etcd
PROCESSOR_CMD
=
"RUST_LOG=info python3 -m kv_router.processor
$MODEL_ARGS
|& tee
$LOGS_DIR
/processor.log"
PROCESSOR_CMD
=
"RUST_LOG=info python3 -m kv_router.processor
\
--model
$MODEL_NAME
\
--tokenizer
$MODEL_NAME
\
--enable-prefix-caching
\
--block-size 32
\
--max-model-len 16384 "
tmux new-session
-d
-s
"
$SESSION_NAME
-processor"
tmux new-session
-d
-s
"
$SESSION_NAME
-processor"
tmux send-keys
-t
"
$SESSION_NAME
-processor"
"
$INIT_CMD
&&
$PROCESSOR_CMD
"
C-m
tmux send-keys
-t
"
$SESSION_NAME
-processor"
"
$INIT_CMD
&&
$PROCESSOR_CMD
"
C-m
...
@@ -72,10 +97,7 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
...
@@ -72,10 +97,7 @@ tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
# Router
# Router
########################################################
########################################################
ROUTER_CMD
=
"RUST_LOG=info python3 -m kv_router.router
\
ROUTER_CMD
=
"RUST_LOG=info python3 -m kv_router.router
\
--model
$MODEL_NAME
\
--min-workers
$NUM_WORKERS
|& tee
$LOGS_DIR
/router.log"
--routing-strategy
$ROUTING_STRATEGY
\
--min-workers
$NUM_WORKERS
\
--block-size 32"
tmux new-session
-d
-s
"
$SESSION_NAME
-router"
tmux new-session
-d
-s
"
$SESSION_NAME
-router"
tmux send-keys
-t
"
$SESSION_NAME
-router"
"
$INIT_CMD
&&
$ROUTER_CMD
"
C-m
tmux send-keys
-t
"
$SESSION_NAME
-router"
"
$INIT_CMD
&&
$ROUTER_CMD
"
C-m
...
@@ -83,17 +105,12 @@ tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
...
@@ -83,17 +105,12 @@ tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
########################################################
########################################################
# Workers
# Workers
########################################################
########################################################
WORKER_CMD
=
"RUST_LOG=info python3 -m kv_router.worker
\
WORKER_CMD
=
"RUST_LOG=info python3 -m kv_router.worker
$MODEL_ARGS
"
--model
$MODEL_NAME
\
--tokenizer
$MODEL_NAME
\
--enable-prefix-caching
\
--block-size 64
\
--max-model-len 16384 "
for
i
in
$(
seq
1
$NUM_WORKERS
)
;
do
for
i
in
$(
seq
1
$NUM_WORKERS
)
;
do
tmux new-session
-d
-s
"
$SESSION_NAME
-
$i
"
tmux new-session
-d
-s
"
$SESSION_NAME
-
$i
"
done
done
for
i
in
$(
seq
1
$NUM_WORKERS
)
;
do
for
i
in
$(
seq
1
$NUM_WORKERS
)
;
do
tmux send-keys
-t
"
$SESSION_NAME
-
$i
"
"
$INIT_CMD
&& CUDA_VISIBLE_DEVICES=
$((
i-1
))
$WORKER_CMD
"
C-m
tmux send-keys
-t
"
$SESSION_NAME
-
$i
"
"
$INIT_CMD
&& CUDA_VISIBLE_DEVICES=
$((
i-1
))
$WORKER_CMD
|& tee
$LOGS_DIR
/worker-
$i
.log
"
C-m
done
done
\ No newline at end of file
examples/rust/llmctl/src/main.rs
View file @
3f84cdad
...
@@ -228,6 +228,10 @@ async fn add_model(
...
@@ -228,6 +228,10 @@ async fn add_model(
endpoint_name
endpoint_name
);
);
if
model_name
.starts_with
(
'/'
)
{
raise!
(
"Model name '{}' cannot start with a slash"
,
model_name
);
}
let
parts
:
Vec
<&
str
>
=
endpoint_name
.split
(
'.'
)
.collect
();
let
parts
:
Vec
<&
str
>
=
endpoint_name
.split
(
'.'
)
.collect
();
if
parts
.len
()
<
2
{
if
parts
.len
()
<
2
{
...
...
lib/bindings/python/rust/llm/kv.rs
View file @
3f84cdad
...
@@ -90,6 +90,7 @@ impl KvMetricsPublisher {
...
@@ -90,6 +90,7 @@ impl KvMetricsPublisher {
})
})
}
}
#[allow(clippy::too_many_arguments)]
fn
publish
(
fn
publish
(
&
self
,
&
self
,
_
py
:
Python
,
_
py
:
Python
,
...
@@ -97,6 +98,9 @@ impl KvMetricsPublisher {
...
@@ -97,6 +98,9 @@ impl KvMetricsPublisher {
request_total_slots
:
u64
,
request_total_slots
:
u64
,
kv_active_blocks
:
u64
,
kv_active_blocks
:
u64
,
kv_total_blocks
:
u64
,
kv_total_blocks
:
u64
,
num_requests_waiting
:
u64
,
gpu_cache_usage_perc
:
f32
,
gpu_prefix_cache_hit_rate
:
f32
,
)
->
PyResult
<
()
>
{
)
->
PyResult
<
()
>
{
self
.inner
self
.inner
.publish
(
.publish
(
...
@@ -105,6 +109,9 @@ impl KvMetricsPublisher {
...
@@ -105,6 +109,9 @@ impl KvMetricsPublisher {
request_total_slots
,
request_total_slots
,
kv_active_blocks
,
kv_active_blocks
,
kv_total_blocks
,
kv_total_blocks
,
num_requests_waiting
,
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
,
}
}
.into
(),
.into
(),
)
)
...
@@ -180,6 +187,10 @@ impl KvIndexer {
...
@@ -180,6 +187,10 @@ impl KvIndexer {
})
})
}
}
fn
block_size
(
&
self
)
->
usize
{
self
.inner
.block_size
()
}
fn
find_matches_for_request
<
'p
>
(
fn
find_matches_for_request
<
'p
>
(
&
self
,
&
self
,
py
:
Python
<
'p
>
,
py
:
Python
<
'p
>
,
...
@@ -212,6 +223,12 @@ pub(crate) struct EndpointKvMetrics {
...
@@ -212,6 +223,12 @@ pub(crate) struct EndpointKvMetrics {
pub
kv_active_blocks
:
u64
,
pub
kv_active_blocks
:
u64
,
#[pyo3(get,
set)]
#[pyo3(get,
set)]
pub
kv_total_blocks
:
u64
,
pub
kv_total_blocks
:
u64
,
#[pyo3(get,
set)]
pub
num_requests_waiting
:
u64
,
#[pyo3(get,
set)]
pub
gpu_cache_usage_perc
:
f32
,
#[pyo3(get,
set)]
pub
gpu_prefix_cache_hit_rate
:
f32
,
}
}
#[pyclass]
#[pyclass]
...
@@ -258,6 +275,9 @@ impl KvMetricsAggregator {
...
@@ -258,6 +275,9 @@ impl KvMetricsAggregator {
request_total_slots
:
x
.data.request_total_slots
,
request_total_slots
:
x
.data.request_total_slots
,
kv_active_blocks
:
x
.data.kv_active_blocks
,
kv_active_blocks
:
x
.data.kv_active_blocks
,
kv_total_blocks
:
x
.data.kv_total_blocks
,
kv_total_blocks
:
x
.data.kv_total_blocks
,
num_requests_waiting
:
x
.data.num_requests_waiting
,
gpu_cache_usage_perc
:
x
.data.gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
:
x
.data.gpu_prefix_cache_hit_rate
,
})
})
.collect
();
.collect
();
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
3f84cdad
...
@@ -242,7 +242,11 @@ class KvMetricsPublisher:
...
@@ -242,7 +242,11 @@ class KvMetricsPublisher:
def publish(self, request_active_slots: int,
def publish(self, request_active_slots: int,
request_total_slots: int,
request_total_slots: int,
kv_active_blocks: int,
kv_active_blocks: int,
kv_total_blocks: int) -> None:
kv_total_blocks: int,
num_requests_waiting: int,
gpu_cache_usage_perc: float,
gpu_prefix_cache_hit_rate: float
) -> None:
"""
"""
Update the KV metrics being reported.
Update the KV metrics being reported.
"""
"""
...
@@ -298,7 +302,7 @@ class KvIndexer:
...
@@ -298,7 +302,7 @@ class KvIndexer:
...
...
def __init__(self, component: Component) -> None:
def __init__(self, component: Component
, block_size: int
) -> None:
"""
"""
Create a `KvIndexer` object
Create a `KvIndexer` object
"""
"""
...
@@ -309,6 +313,12 @@ class KvIndexer:
...
@@ -309,6 +313,12 @@ class KvIndexer:
"""
"""
...
...
def block_size(self) -> int:
"""
Return the block size of the KV Indexer.
"""
...
class AggregatedMetrics:
class AggregatedMetrics:
"""
"""
A collection of metrics of the endpoints
A collection of metrics of the endpoints
...
...
lib/bindings/python/src/dynamo/llm/__init__.py
View file @
3f84cdad
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
dynamo._core
import
AggregatedMetrics
as
AggregatedMetrics
from
dynamo._core
import
DisaggregatedRouter
as
DisaggregatedRouter
from
dynamo._core
import
DisaggregatedRouter
as
DisaggregatedRouter
from
dynamo._core
import
HttpAsyncEngine
as
HttpAsyncEngine
from
dynamo._core
import
HttpAsyncEngine
as
HttpAsyncEngine
from
dynamo._core
import
HttpError
as
HttpError
from
dynamo._core
import
HttpError
as
HttpError
...
@@ -21,3 +22,4 @@ from dynamo._core import KvIndexer as KvIndexer
...
@@ -21,3 +22,4 @@ from dynamo._core import KvIndexer as KvIndexer
from
dynamo._core
import
KvMetricsAggregator
as
KvMetricsAggregator
from
dynamo._core
import
KvMetricsAggregator
as
KvMetricsAggregator
from
dynamo._core
import
KvMetricsPublisher
as
KvMetricsPublisher
from
dynamo._core
import
KvMetricsPublisher
as
KvMetricsPublisher
from
dynamo._core
import
KvRouter
as
KvRouter
from
dynamo._core
import
KvRouter
as
KvRouter
from
dynamo._core
import
OverlapScores
as
OverlapScores
lib/bindings/python/tests/test_kv_bindings.py
View file @
3f84cdad
...
@@ -193,6 +193,9 @@ async def test_metrics_aggregator(distributed_runtime):
...
@@ -193,6 +193,9 @@ async def test_metrics_aggregator(distributed_runtime):
"request_total_slots"
:
1024
,
"request_total_slots"
:
1024
,
"kv_active_blocks"
:
523
,
"kv_active_blocks"
:
523
,
"kv_total_blocks"
:
777
,
"kv_total_blocks"
:
777
,
"num_requests_waiting"
:
10
,
"gpu_cache_usage_perc"
:
0.5
,
"gpu_prefix_cache_hit_rate"
:
0.75
,
}
}
# need 'create_task' to put publisher task in the background
# need 'create_task' to put publisher task in the background
...
@@ -222,5 +225,8 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
...
@@ -222,5 +225,8 @@ async def metrics_publisher_task(kv_listener, expected_metrics):
expected_metrics
[
"request_total_slots"
],
expected_metrics
[
"request_total_slots"
],
expected_metrics
[
"kv_active_blocks"
],
expected_metrics
[
"kv_active_blocks"
],
expected_metrics
[
"kv_total_blocks"
],
expected_metrics
[
"kv_total_blocks"
],
expected_metrics
[
"num_requests_waiting"
],
expected_metrics
[
"gpu_cache_usage_perc"
],
expected_metrics
[
"gpu_prefix_cache_hit_rate"
],
)
)
await
metrics_publisher
.
create_endpoint
(
kv_listener
)
await
metrics_publisher
.
create_endpoint
(
kv_listener
)
lib/llm/src/kv_router.rs
View file @
3f84cdad
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
dynamo_runtime
::{
component
::
Component
,
component
::
Namespace
,
DistributedRuntime
};
use
dynamo_runtime
::{
component
::
Component
,
component
::
Namespace
,
DistributedRuntime
};
use
futures
::
stream
::
StreamExt
;
use
futures
::
stream
::
StreamExt
;
use
std
::
{
sync
::
Arc
,
time
::
Duration
}
;
use
std
::
sync
::
Arc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tracing
;
use
tracing
;
...
@@ -29,7 +29,8 @@ pub mod scoring;
...
@@ -29,7 +29,8 @@ pub mod scoring;
use
crate
::
kv_router
::{
use
crate
::
kv_router
::{
indexer
::{
KvIndexer
,
KvIndexerInterface
,
RouterEvent
},
indexer
::{
KvIndexer
,
KvIndexerInterface
,
RouterEvent
},
scheduler
::{
Endpoint
,
KvScheduler
,
Service
},
metrics_aggregator
::
collect_endpoints
,
scheduler
::
KvScheduler
,
scoring
::
ProcessedEndpoints
,
scoring
::
ProcessedEndpoints
,
};
};
...
@@ -148,73 +149,3 @@ impl KvRouter {
...
@@ -148,73 +149,3 @@ impl KvRouter {
Ok
(
worker_id
)
Ok
(
worker_id
)
}
}
}
}
async
fn
collect_endpoints
(
nats_client
:
dynamo_runtime
::
transports
::
nats
::
Client
,
service_name
:
String
,
ep_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
ProcessedEndpoints
>
,
cancel
:
CancellationToken
,
)
{
loop
{
tokio
::
select!
{
_
=
cancel
.cancelled
()
=>
{
tracing
::
debug!
(
"cancellation token triggered"
);
break
;
}
_
=
tokio
::
time
::
sleep
(
Duration
::
from_secs
(
1
))
=>
{
tracing
::
trace!
(
"collecting endpoints for service: {}"
,
service_name
);
}
}
let
values
=
match
nats_client
.get_endpoints
(
&
service_name
,
Duration
::
from_secs
(
1
))
.await
{
Ok
(
v
)
=>
v
,
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to retrieve endpoints for {}: {:?}"
,
service_name
,
e
);
continue
;
}
};
tracing
::
debug!
(
"values: {:?}"
,
values
);
let
services
:
Vec
<
Service
>
=
values
.into_iter
()
.filter
(|
v
|
!
v
.is_empty
())
.filter_map
(|
v
|
match
serde_json
::
from_slice
::
<
Service
>
(
&
v
)
{
Ok
(
service
)
=>
Some
(
service
),
Err
(
e
)
=>
{
tracing
::
warn!
(
"For value: {:?}
\n
Failed to parse service: {:?}"
,
v
,
e
);
None
}
})
.collect
();
tracing
::
debug!
(
"services: {:?}"
,
services
);
let
endpoints
:
Vec
<
Endpoint
>
=
services
.into_iter
()
.flat_map
(|
s
|
s
.endpoints
)
.filter
(|
s
|
s
.data
.is_some
())
.map
(|
s
|
Endpoint
{
name
:
s
.name
,
subject
:
s
.subject
,
data
:
s
.data
.unwrap
(),
})
.collect
();
tracing
::
debug!
(
"endpoints: {:?}"
,
endpoints
);
tracing
::
trace!
(
"found {} endpoints for service: {}"
,
endpoints
.len
(),
service_name
);
let
processed
=
ProcessedEndpoints
::
new
(
endpoints
);
// process endpoints into
if
ep_tx
.send
(
processed
)
.await
.is_err
()
{
tracing
::
trace!
(
"failed to send processed endpoints; shutting down"
);
break
;
}
}
}
lib/llm/src/kv_router/indexer.rs
View file @
3f84cdad
...
@@ -588,6 +588,10 @@ impl KvIndexer {
...
@@ -588,6 +588,10 @@ impl KvIndexer {
}
}
}
}
pub
fn
block_size
(
&
self
)
->
usize
{
self
.kv_block_size
}
pub
fn
new
(
token
:
CancellationToken
,
kv_block_size
:
usize
)
->
Self
{
pub
fn
new
(
token
:
CancellationToken
,
kv_block_size
:
usize
)
->
Self
{
Self
::
new_with_frequency
(
token
,
None
,
kv_block_size
)
Self
::
new_with_frequency
(
token
,
None
,
kv_block_size
)
}
}
...
@@ -775,6 +779,10 @@ impl KvIndexerSharded {
...
@@ -775,6 +779,10 @@ impl KvIndexerSharded {
}
}
}
}
pub
fn
block_size
(
&
self
)
->
usize
{
self
.kv_block_size
}
pub
fn
new
(
token
:
CancellationToken
,
num_shards
:
usize
,
kv_block_size
:
usize
)
->
Self
{
pub
fn
new
(
token
:
CancellationToken
,
num_shards
:
usize
,
kv_block_size
:
usize
)
->
Self
{
Self
::
new_with_frequency
(
token
,
num_shards
,
None
,
kv_block_size
)
Self
::
new_with_frequency
(
token
,
num_shards
,
None
,
kv_block_size
)
}
}
...
...
lib/llm/src/kv_router/metrics_aggregator.rs
View file @
3f84cdad
...
@@ -80,25 +80,24 @@ impl KvMetricsAggregator {
...
@@ -80,25 +80,24 @@ impl KvMetricsAggregator {
}
}
}
}
async
fn
collect_endpoints
(
pub
async
fn
collect_endpoints
(
nats_client
:
dynamo_runtime
::
transports
::
nats
::
Client
,
nats_client
:
dynamo_runtime
::
transports
::
nats
::
Client
,
service_name
:
String
,
service_name
:
String
,
ep_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
ProcessedEndpoints
>
,
ep_tx
:
tokio
::
sync
::
mpsc
::
Sender
<
ProcessedEndpoints
>
,
cancel
:
CancellationToken
,
cancel
:
CancellationToken
,
)
{
)
{
let
backoff_delay
=
Duration
::
from_millis
(
100
);
loop
{
loop
{
tokio
::
select!
{
tokio
::
select!
{
_
=
cancel
.cancelled
()
=>
{
_
=
cancel
.cancelled
()
=>
{
tracing
::
debug!
(
"cancellation token triggered"
);
tracing
::
debug!
(
"cancellation token triggered"
);
break
;
break
;
}
}
_
=
tokio
::
time
::
sleep
(
Duration
::
from_secs
(
1
)
)
=>
{
_
=
tokio
::
time
::
sleep
(
backoff_delay
)
=>
{
tracing
::
trace!
(
"collecting endpoints for service: {}"
,
service_name
);
tracing
::
trace!
(
"collecting endpoints for service: {}"
,
service_name
);
}
}
let
values
=
match
nats_client
let
values
=
match
nats_client
.get_endpoints
(
&
service_name
,
Duration
::
from_
secs
(
1
))
.get_endpoints
(
&
service_name
,
Duration
::
from_
millis
(
300
))
.await
.await
{
{
Ok
(
v
)
=>
v
,
Ok
(
v
)
=>
v
,
...
@@ -146,4 +145,6 @@ async fn collect_endpoints(
...
@@ -146,4 +145,6 @@ async fn collect_endpoints(
break
;
break
;
}
}
}
}
}
}
}
}
lib/llm/src/kv_router/protocols.rs
View file @
3f84cdad
...
@@ -21,6 +21,12 @@ pub struct ForwardPassMetrics {
...
@@ -21,6 +21,12 @@ pub struct ForwardPassMetrics {
pub
request_total_slots
:
u64
,
pub
request_total_slots
:
u64
,
pub
kv_active_blocks
:
u64
,
pub
kv_active_blocks
:
u64
,
pub
kv_total_blocks
:
u64
,
pub
kv_total_blocks
:
u64
,
// integer from 0 to large number
pub
num_requests_waiting
:
u64
,
// percentage represented as a float from 0 to 1
pub
gpu_cache_usage_perc
:
f32
,
// percentage represented as a float from 0 to 1
pub
gpu_prefix_cache_hit_rate
:
f32
,
}
}
/// A [`BlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
/// A [`BlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment