"examples/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "dd7646ef33354f342959b22ec07e5f4a4f1d3da0"
Commit e159e53f authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: expose KV routing components for easier router customization (#15)

parent ea78a424
...@@ -23,7 +23,7 @@ import uvloop ...@@ -23,7 +23,7 @@ import uvloop
from common.protocol import Tokens from common.protocol import Tokens
from vllm.logger import logger as vllm_logger from vllm.logger import logger as vllm_logger
from dynemo.llm import KvRouter from dynemo.llm import KvIndexer, KvMetricsAggregator, KvRouter
from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker
WorkerId = str WorkerId = str
...@@ -78,6 +78,60 @@ class Router: ...@@ -78,6 +78,60 @@ class Router:
) )
class CustomRouter:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator,
):
self.indexer = indexer
self.metrics_aggregator = metrics_aggregator
def _cost_function(self, scores, metrics):
# naive cost function for demonstration purposes
current_best = ("", 0)
for worker_id, score in scores.scores.items():
if score > current_best[1]:
current_best = (worker_id, score)
for endpoint in metrics.endpoints:
if endpoint.worker_id == current_best[0]:
print(f"Metrics of endpoint: {endpoint.worker_id}")
print(
f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}"
)
print(
f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
)
return current_best[0]
@dynemo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield str(worker_id)
@dynemo_worker() @dynemo_worker()
async def worker(runtime: DistributedRuntime, args: Namespace): async def worker(runtime: DistributedRuntime, args: Namespace):
""" """
...@@ -116,10 +170,17 @@ async def worker(runtime: DistributedRuntime, args: Namespace): ...@@ -116,10 +170,17 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
router_component = runtime.namespace("dynemo").component("router") router_component = runtime.namespace("dynemo").component("router")
await router_component.create_service() await router_component.create_service()
router = KvRouter(runtime, kv_listener)
endpoint = router_component.endpoint("generate") endpoint = router_component.endpoint("generate")
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
if args.custom_router:
indexer = KvIndexer(kv_listener)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate
)
else:
router = KvRouter(runtime, kv_listener)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -147,6 +208,12 @@ if __name__ == "__main__": ...@@ -147,6 +208,12 @@ if __name__ == "__main__":
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served", help="Model that is being served",
) )
parser.add_argument(
"--custom-router",
type=bool,
default=False,
help="Whether to use custom router or not",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(worker(args)) asyncio.run(worker(args))
...@@ -70,6 +70,11 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -70,6 +70,11 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::model_card::ModelDeploymentCard>()?; m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?; m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
m.add_class::<llm::backend::Backend>()?; m.add_class::<llm::backend::Backend>()?;
m.add_class::<llm::kv::OverlapScores>()?;
m.add_class::<llm::kv::KvIndexer>()?;
m.add_class::<llm::kv::EndpointKvMetrics>()?;
m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?;
engine::add_to_module(m)?; engine::add_to_module(m)?;
......
...@@ -13,7 +13,11 @@ ...@@ -13,7 +13,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::HashMap;
use super::*; use super::*;
use llm_rs::kv_router::indexer::KvIndexerInterface;
use tracing;
#[pyclass] #[pyclass]
pub(crate) struct KvRouter { pub(crate) struct KvRouter {
...@@ -106,3 +110,160 @@ impl KvMetricsPublisher { ...@@ -106,3 +110,160 @@ impl KvMetricsPublisher {
.map_err(to_pyerr) .map_err(to_pyerr)
} }
} }
#[pyclass]
#[derive(Clone)]
pub(crate) struct OverlapScores {
inner: llm_rs::kv_router::indexer::OverlapScores,
}
#[pymethods]
impl OverlapScores {
#[getter]
fn scores(&self) -> HashMap<llm_rs::kv_router::indexer::WorkerId, u32> {
self.inner.scores.clone()
}
#[getter]
fn frequencies(&self) -> Vec<usize> {
self.inner.frequencies.clone()
}
}
#[pyclass]
pub(crate) struct KvIndexer {
inner: Arc<llm_rs::kv_router::indexer::KvIndexer>,
}
#[pymethods]
impl KvIndexer {
#[new]
fn new(component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let kv_subject = component
.inner
.event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT);
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(),
)
.into();
let mut kv_events_rx = component
.inner
.drt()
.nats_client()
.client()
.subscribe(kv_subject)
.await
.map_err(to_pyerr)?;
let kv_events_tx = inner.event_sender();
// [FIXME] this is the added functionality to the indexer to subscribe to kv events,
// should have been made to a trait and implemented here? i.e. AsyncEngine style
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::indexer::RouterEvent =
serde_json::from_slice(&event.payload).unwrap();
tracing::debug!("received kv event: {:?}", event);
if let Err(e) = kv_events_tx.send(event).await {
tracing::trace!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
Ok(Self { inner })
})
}
fn find_matches_for_request<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
_lora_id: u64,
) -> PyResult<Bound<'p, PyAny>> {
let indexer = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let rs_overlap_scores = indexer
.find_matches_for_request(token_ids.as_slice())
.await
.map_err(to_pyerr)?;
Ok(OverlapScores {
inner: rs_overlap_scores,
})
})
}
}
#[pyclass]
#[derive(Clone)]
pub(crate) struct EndpointKvMetrics {
#[pyo3(get, set)]
pub worker_id: i64,
#[pyo3(get, set)]
pub request_active_slots: u64,
#[pyo3(get, set)]
pub request_total_slots: u64,
#[pyo3(get, set)]
pub kv_active_blocks: u64,
#[pyo3(get, set)]
pub kv_total_blocks: u64,
}
#[pyclass]
#[derive(Clone)]
pub(crate) struct AggregatedMetrics {
#[pyo3(get, set)]
pub endpoints: Vec<EndpointKvMetrics>,
#[pyo3(get, set)]
pub load_avg: f64,
#[pyo3(get, set)]
pub load_std: f64,
}
#[pyclass]
pub(crate) struct KvMetricsAggregator {
inner: Arc<llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator>,
}
#[pymethods]
impl KvMetricsAggregator {
#[new]
fn new(component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::metrics_aggregator::KvMetricsAggregator::new(
component.inner.clone(),
component.inner.drt().runtime().child_token(),
)
.await;
Ok(Self {
inner: inner.into(),
})
})
}
fn get_metrics<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let endpoints = self.inner.get_endpoints();
let endpoint_kv_metrics = endpoints
.endpoints
.iter()
.map(|x| EndpointKvMetrics {
worker_id: x.worker_id(),
request_active_slots: x.data.request_active_slots,
request_total_slots: x.data.request_total_slots,
kv_active_blocks: x.data.kv_active_blocks,
kv_total_blocks: x.data.kv_total_blocks,
})
.collect();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(AggregatedMetrics {
endpoints: endpoint_kv_metrics,
load_avg: endpoints.load_avg,
load_std: endpoints.load_std,
})
})
}
}
...@@ -233,3 +233,55 @@ class Backend: ...@@ -233,3 +233,55 @@ class Backend:
Start the backend engine and requests to the downstream LLM engine Start the backend engine and requests to the downstream LLM engine
""" """
... ...
class OverlapScores:
"""
A collection of prefix matching scores of workers for a given token ids.
'scores' is a map of worker id to the score which is the number of matching blocks.
"""
...
class KvIndexer:
"""
A KV Indexer that tracks KV Events emitted by workers. Events include add_block and remove_block.
"""
...
def __init__(self, component: Component) -> None:
"""
Create a `KvIndexer` object
"""
def find_matches_for_request(self, token_ids: List[int], lora_id: int) -> OverlapScores:
"""
Return the overlapping scores of workers for the given token ids.
"""
...
class AggregatedMetrics:
"""
A collection of metrics of the endpoints
"""
...
class KvMetricsAggregator:
"""
A metrics aggregator will collect KV metrics of the endpoints.
"""
...
def __init__(self, component: Component) -> None:
"""
Create a `KvMetricsAggregator` object
"""
def get_metrics(self) -> AggregatedMetrics:
"""
Return the aggregated metrics of the endpoints.
"""
...
...@@ -13,5 +13,7 @@ ...@@ -13,5 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dynemo._core import KvIndexer as KvIndexer
from dynemo._core import KvMetricsAggregator as KvMetricsAggregator
from dynemo._core import KvMetricsPublisher as KvMetricsPublisher from dynemo._core import KvMetricsPublisher as KvMetricsPublisher
from dynemo._core import KvRouter as KvRouter from dynemo._core import KvRouter as KvRouter
...@@ -21,6 +21,7 @@ use tokio_util::sync::CancellationToken; ...@@ -21,6 +21,7 @@ use tokio_util::sync::CancellationToken;
use tracing; use tracing;
pub mod indexer; pub mod indexer;
pub mod metrics_aggregator;
pub mod protocols; pub mod protocols;
pub mod publisher; pub mod publisher;
pub mod scheduler; pub mod scheduler;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::{Arc, Mutex};
pub use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::scheduler::{Endpoint, Service};
use crate::kv_router::ProcessedEndpoints;
use dynemo_runtime::component::Component;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
pub struct KvMetricsAggregator {
pub service_name: String,
pub endpoints: Arc<Mutex<ProcessedEndpoints>>,
}
impl KvMetricsAggregator {
pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128);
tokio::spawn(collect_endpoints(
component.drt().nats_client().clone(),
component.service_name(),
ep_tx,
cancellation_token.clone(),
));
tracing::trace!("awaiting the start of the background endpoint subscriber");
let endpoints = Arc::new(Mutex::new(ProcessedEndpoints::default()));
let endpoints_clone = endpoints.clone();
tokio::spawn(async move {
tracing::debug!("scheduler background task started");
loop {
match ep_rx.recv().await {
Some(endpoints) => match endpoints_clone.lock() {
Ok(mut shared_endpoint) => {
*shared_endpoint = endpoints;
}
Err(e) => {
tracing::error!("Failed to acquire lock on endpoints: {:?}", e);
}
},
None => {
tracing::warn!("endpoint subscriber shutdown");
break;
}
};
}
tracing::trace!("background endpoint subscriber shutting down");
});
Self {
service_name: component.service_name(),
endpoints,
}
}
pub fn get_endpoints(&self) -> ProcessedEndpoints {
match self.endpoints.lock() {
Ok(endpoints) => endpoints.clone(),
Err(e) => {
tracing::error!("Failed to acquire lock on endpoints: {:?}", e);
ProcessedEndpoints::default()
}
}
}
}
async fn collect_endpoints(
nats_client: dynemo_runtime::transports::nats::Client,
service_name: String,
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
cancel: CancellationToken,
) {
loop {
tokio::select! {
_ = cancel.cancelled() => {
tracing::debug!("cancellation token triggered");
break;
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
tracing::trace!("collecting endpoints for service: {}", service_name);
}
}
let values = match nats_client
.get_endpoints(&service_name, Duration::from_secs(1))
.await
{
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values
.into_iter()
.filter(|v| !v.is_empty())
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service),
Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None
}
})
.collect();
tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services
.into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some())
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: s.data.unwrap(),
})
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
tracing::trace!(
"found {} endpoints for service: {}",
endpoints.len(),
service_name
);
let processed = ProcessedEndpoints::new(endpoints);
if ep_tx.send(processed).await.is_err() {
tracing::trace!("failed to send processed endpoints; shutting down");
break;
}
}
}
...@@ -20,7 +20,7 @@ use std::collections::HashSet; ...@@ -20,7 +20,7 @@ use std::collections::HashSet;
use crate::kv_router::scheduler::Endpoint; use crate::kv_router::scheduler::Endpoint;
#[derive(Debug, Default, Serialize, Deserialize)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct ProcessedEndpoints { pub struct ProcessedEndpoints {
pub endpoints: Vec<Endpoint>, pub endpoints: Vec<Endpoint>,
pub worker_ids: Vec<i64>, pub worker_ids: Vec<i64>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment