mock_worker.rs 5.65 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Neelay Shah's avatar
Neelay Shah committed
16
use dynamo_llm::kv_router::{
17
18
    protocols::ForwardPassMetrics, scheduler::KVHitRateEvent, KV_HIT_RATE_SUBJECT,
};
Neelay Shah's avatar
Neelay Shah committed
19
use dynamo_runtime::{
20
    component::{service::EndpointStats, Namespace},
21
22
23
24
25
26
    logging,
    pipeline::{
        async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
        ResponseStream, SingleIn,
    },
    protocols::annotated::Annotated,
27
28
29
    stream,
    traits::events::EventPublisher,
    DistributedRuntime, Result, Runtime, Worker,
30
};
31
32
use rand::Rng;
use std::sync::Arc;
33
use tokio::time::{interval, Duration};
34
35
36
37
38
39
40
41
42
43
44
45

/// Process entry point.
///
/// Initializes logging, constructs a [`Worker`] from its settings and runs the
/// async [`app`] function on it, propagating any error.
fn main() -> Result<()> {
    logging::init();
    Worker::from_settings()?.execute(app)
}

/// Async application body executed by the [`Worker`].
///
/// Builds a [`DistributedRuntime`] on top of (a clone of) the local `runtime`
/// and then hands it to [`backend`], returning whatever `backend` returns.
async fn app(runtime: Runtime) -> Result<()> {
    let drt = DistributedRuntime::from_settings(runtime.clone()).await?;
    backend(drt).await
}

46
/// Stateless mock engine used by this worker to answer requests.
///
/// It carries no data; see its `AsyncEngine` implementation for the response
/// behavior.
struct MockRequestHandler {}

impl MockRequestHandler {
    /// Creates a new handler wrapped in an [`Arc`], the form expected by
    /// `Ingress::for_engine`.
    fn new() -> Arc<Self> {
        Arc::new(Self {})
    }
}

#[async_trait]
55
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for MockRequestHandler {
56
57
58
59
60
61
62
63
64
65
66
67
68
69
    async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
        let (data, ctx) = input.into_parts();

        let chars = data
            .chars()
            .map(|c| Annotated::from_data(c.to_string()))
            .collect::<Vec<_>>();

        let stream = stream::iter(chars);

        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}

70
// FIXME: These events are just for testing and may not currently be used.
71
72
73
74
75
76
77
78
79
80
/// Spawns a background task that periodically publishes mock KV hit rate events
async fn mock_event_publisher(namespace: Namespace) {
    // NOTE: These events are just for testing, and shouldn't be interpreted
    // in correlation with the stats handler's data:
    // 1. The worker ID associated with the events here won't match the
    // worker ID of the endpoint's service stats handler.
    // 2. These events aren't coming through the KV Router, so the metrics won't
    // be reflective of the KV Router's performance.
    // 3. The data in these events aren't in sync with the stats handler's
    // ForwardPassMetrics data, so they may not correlate well.
81
    let worker_id = rand::rng().random_range(1..=1000);
82
83
84
85
86
87

    let mut interval = interval(Duration::from_secs(1));
    loop {
        interval.tick().await;

        // Generate random KV hit rate event using a new thread_rng each time
88
89
        let isl_blocks = rand::rng().random_range(0..=100);
        let overlap_blocks = rand::rng().random_range(0..=isl_blocks);
90
91
92
93
94
95
96
97
98
99

        let event = KVHitRateEvent {
            worker_id,
            isl_blocks,
            overlap_blocks,
        };

        if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await {
            tracing::warn!("Failed to publish KV hit rate event: {e}");
        } else {
100
            tracing::debug!(
101
102
103
104
105
106
107
108
                "Published KV hit rate event: worker_id={worker_id}, isl_blocks={isl_blocks}, overlap_blocks={overlap_blocks}, hit_rate={:.2}%",
                (overlap_blocks as f64 / isl_blocks as f64) * 100.0
            );
        }
    }
}

/// Generates mock forward pass metrics for stats handler
109
fn mock_stats_handler(_stats: EndpointStats) -> serde_json::Value {
110
    let request_total_slots = 100;
111
    let request_active_slots = rand::rng().random_range(0..=request_total_slots);
112
    let kv_total_blocks = 100;
113
114
115
116
    let kv_active_blocks = rand::rng().random_range(0..=kv_total_blocks);
    let num_requests_waiting = rand::rng().random_range(0..=100);
    let gpu_cache_usage_perc = rand::rng().random_range(0.0..=1.0);
    let gpu_prefix_cache_hit_rate = rand::rng().random_range(0.0..=1.0);
117
118
119
120
121
    let stats = ForwardPassMetrics {
        request_active_slots,
        request_total_slots,
        kv_active_blocks,
        kv_total_blocks,
122
123
124
        num_requests_waiting,
        gpu_cache_usage_perc,
        gpu_prefix_cache_hit_rate,
125
    };
126
    tracing::info!("Stats: {stats:?}");
127
128
129
    serde_json::to_value(stats).unwrap()
}

130
async fn backend(runtime: DistributedRuntime) -> Result<()> {
Neelay Shah's avatar
Neelay Shah committed
131
    let namespace = runtime.namespace("dynamo")?;
132
133
134
135
136
137
138
139
    // we must first create a service, then we can attach one more more endpoints
    let component = namespace
        .component("my_component")?
        .service_builder()
        .create()
        .await?;
    let endpoint = component.endpoint("my_endpoint");
    tracing::info!("Starting Mock Worker on Endpoint: {}", endpoint.path());
140
141
142
143
144
145
146

    // Spawn background task for publishing KV hit rate events
    let namespace_clone = namespace.clone();
    tokio::spawn(async move {
        mock_event_publisher(namespace_clone).await;
    });

147
    // Attach an ingress to the engine
148
    let ingress = Ingress::for_engine(MockRequestHandler::new())?;
149

150
151
    // Make the ingress discoverable via a component service
    endpoint
152
153
        .endpoint_builder()
        // Dummy stats handler to demonstrate how to attach a custom stats handler
154
        .stats_handler(mock_stats_handler)
155
156
157
158
        .handler(ingress)
        .start()
        .await
}