mock_worker.rs 5.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use async_nats::service::endpoint::Stats;
Neelay Shah's avatar
Neelay Shah committed
17
use dynamo_llm::kv_router::{
18
19
    protocols::ForwardPassMetrics, scheduler::KVHitRateEvent, KV_HIT_RATE_SUBJECT,
};
Neelay Shah's avatar
Neelay Shah committed
20
use dynamo_runtime::{
21
    component::Namespace,
22
23
24
25
26
27
    logging,
    pipeline::{
        async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
        ResponseStream, SingleIn,
    },
    protocols::annotated::Annotated,
28
29
30
    stream,
    traits::events::EventPublisher,
    DistributedRuntime, Result, Runtime, Worker,
31
};
32
33
use rand::Rng;
use std::sync::Arc;
34
use tokio::time::{interval, Duration};
35
36
37
38
39
40
41
42
43
44
45
46

/// Process entry point.
///
/// Initializes logging, builds a [`Worker`] from settings, and hands `app` to it for execution.
///
/// # Errors
///
/// Returns the error from `Worker::from_settings` or from `Worker::execute`.
fn main() -> Result<()> {
    logging::init();
    Worker::from_settings()?.execute(app)
}

/// Async application body executed by the worker.
///
/// Builds a [`DistributedRuntime`] from settings on top of a clone of `runtime` and then runs
/// `backend` with it.
///
/// # Errors
///
/// Returns the error from `DistributedRuntime::from_settings` or from `backend`.
async fn app(runtime: Runtime) -> Result<()> {
    backend(DistributedRuntime::from_settings(runtime.clone()).await?).await
}

47
/// Stateless request handler used as the engine behind the mock endpoint.
struct MockRequestHandler {}

impl MockRequestHandler {
    /// Creates a new handler wrapped in an [`Arc`].
    fn new() -> Arc<Self> {
        Arc::new(Self {})
    }
}

#[async_trait]
56
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for MockRequestHandler {
57
58
59
60
61
62
63
64
65
66
67
68
69
70
    async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
        let (data, ctx) = input.into_parts();

        let chars = data
            .chars()
            .map(|c| Annotated::from_data(c.to_string()))
            .collect::<Vec<_>>();

        let stream = stream::iter(chars);

        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/// Publishes a mock KV hit rate event on `namespace` once per second, forever.
///
/// The function loops without a break, so it only ends when the task running it is dropped
/// or aborted. A failed publish is logged as a warning and the loop continues; a successful
/// publish is logged at info level together with the derived hit rate.
async fn mock_event_publisher(namespace: Namespace) {
    // NOTE: These events are just for testing, and shouldn't be interpreted
    // in correlation with the stats handler's data:
    // 1. The worker ID associated with the events here won't match the
    // worker ID of the endpoint's service stats handler.
    // 2. These events aren't coming through the KV Router, so the metrics won't
    // be reflective of the KV Router's performance.
    // 3. The data in these events isn't in sync with the stats handler's
    // ForwardPassMetrics data, so they may not correlate well.
    let worker_id = rand::thread_rng().gen_range(1..=1000);

    let mut interval = interval(Duration::from_secs(1));
    loop {
        interval.tick().await;

        // Generate random KV hit rate event using a new thread_rng each time,
        // so no RNG handle is kept alive across an `.await`.
        let isl_blocks = rand::thread_rng().gen_range(0..=100);
        let overlap_blocks = rand::thread_rng().gen_range(0..=isl_blocks);

        let event = KVHitRateEvent {
            worker_id,
            isl_blocks,
            overlap_blocks,
        };

        if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await {
            tracing::warn!("Failed to publish KV hit rate event: {e}");
        } else {
            // `isl_blocks` can be 0, in which case 0.0 / 0.0 would be logged as NaN.
            // Report a 0% hit rate for an empty input instead.
            let hit_rate = if isl_blocks == 0 {
                0.0
            } else {
                (overlap_blocks as f64 / isl_blocks as f64) * 100.0
            };
            tracing::info!(
                "Published KV hit rate event: worker_id={worker_id}, isl_blocks={isl_blocks}, overlap_blocks={overlap_blocks}, hit_rate={:.2}%",
                hit_rate
            );
        }
    }
}

/// Stats handler that fabricates a random [`ForwardPassMetrics`] sample.
///
/// The incoming service stats are only printed. The returned JSON value holds fixed totals
/// of 100 for request slots and KV blocks, and active counts drawn at random from `0..=100`.
///
/// # Panics
///
/// Panics if the metrics cannot be converted to a JSON value.
fn mock_stats_handler(_stats: Stats) -> serde_json::Value {
    println!("stats in: {:?}", _stats);

    let mut rng = rand::thread_rng();
    let slot_capacity = 100;
    let busy_slots = rng.gen_range(0..=slot_capacity);
    let block_capacity = 100;
    let busy_blocks = rng.gen_range(0..=block_capacity);

    let metrics = ForwardPassMetrics {
        request_active_slots: busy_slots,
        request_total_slots: slot_capacity,
        kv_active_blocks: busy_blocks,
        kv_total_blocks: block_capacity,
    };
    println!("stats out: {:?}", metrics);

    serde_json::to_value(metrics).unwrap()
}

125
async fn backend(runtime: DistributedRuntime) -> Result<()> {
Neelay Shah's avatar
Neelay Shah committed
126
    let namespace = runtime.namespace("dynamo")?;
127
128
129
130
131
132
133

    // Spawn background task for publishing KV hit rate events
    let namespace_clone = namespace.clone();
    tokio::spawn(async move {
        mock_event_publisher(namespace_clone).await;
    });

134
    // attach an ingress to an engine
135
    let ingress = Ingress::for_engine(MockRequestHandler::new())?;
136
137
138

    // make the ingress discoverable via a component service
    // we must first create a service, then we can attach one more more endpoints
139
    namespace
140
141
142
143
144
145
146
        .component("backend")?
        .service_builder()
        .create()
        .await?
        .endpoint("generate")
        .endpoint_builder()
        // Dummy stats handler to demonstrate how to attach a custom stats handler
147
        .stats_handler(mock_stats_handler)
148
149
150
151
        .handler(ingress)
        .start()
        .await
}