mock_worker.rs 5.41 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

Neelay Shah's avatar
Neelay Shah committed
4
use dynamo_llm::kv_router::{
5
    KV_HIT_RATE_SUBJECT,
6
7
    protocols::{ForwardPassMetrics, KvStats, WorkerStats},
    scheduler::KVHitRateEvent,
8
};
Neelay Shah's avatar
Neelay Shah committed
9
use dynamo_runtime::{
10
11
    DistributedRuntime, Result, Runtime, Worker,
    component::{Namespace, service::EndpointStats},
12
13
    logging,
    pipeline::{
14
15
        AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
        async_trait, network::Ingress,
16
17
    },
    protocols::annotated::Annotated,
18
19
    stream,
    traits::events::EventPublisher,
20
};
21
22
use rand::Rng;
use std::sync::Arc;
23
use tokio::time::{Duration, interval};
24
25
26
27
28
29
30
31
32
33
34
35

/// Process entry point: initializes logging, builds a `Worker` from settings,
/// and hands control to [`app`] until it completes.
fn main() -> Result<()> {
    logging::init();
    Worker::from_settings()?.execute(app)
}

/// Builds a `DistributedRuntime` from settings on top of `runtime`, then runs
/// the mock backend on it and returns the backend's result.
async fn app(runtime: Runtime) -> Result<()> {
    let drt = DistributedRuntime::from_settings(runtime.clone()).await?;
    backend(drt).await
}

36
/// Stateless mock request handler; it carries no fields.
struct MockRequestHandler {}

impl MockRequestHandler {
    /// Creates a new handler and returns it wrapped in an `Arc`.
    fn new() -> Arc<Self> {
        let handler = MockRequestHandler {};
        Arc::new(handler)
    }
}

#[async_trait]
45
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for MockRequestHandler {
46
47
48
49
50
51
52
53
54
55
56
57
58
59
    async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
        let (data, ctx) = input.into_parts();

        let chars = data
            .chars()
            .map(|c| Annotated::from_data(c.to_string()))
            .collect::<Vec<_>>();

        let stream = stream::iter(chars);

        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}

60
// FIXME: These events are just for testing and may not currently be used.
61
62
63
64
65
66
67
68
69
70
/// Publishes a mock KV hit rate event on `namespace` once per second, forever.
///
/// This function does not spawn anything itself; it is an endless loop meant
/// to be driven from a background task. A random worker ID in `1..=1000` is
/// chosen once and reused for every event. A failed publish is logged at
/// `warn` level and the loop continues with the next tick.
async fn mock_event_publisher(namespace: Namespace) {
    // NOTE: These events are just for testing, and shouldn't be interpreted
    // in correlation with the stats handler's data:
    // 1. The worker ID associated with the events here won't match the
    // worker ID of the endpoint's service stats handler.
    // 2. These events aren't coming through the KV Router, so the metrics won't
    // be reflective of the KV Router's performance.
    // 3. The data in these events aren't in sync with the stats handler's
    // ForwardPassMetrics data, so they may not correlate well.
    let worker_id = rand::rng().random_range(1..=1000);

    let mut interval = interval(Duration::from_secs(1));
    loop {
        interval.tick().await;

        // Fetch the thread-local RNG handle per draw instead of storing it in
        // a local: the handle is not `Send`, so it must not stay alive across
        // the `.await` points of a future that gets spawned onto the runtime.
        let isl_blocks = rand::rng().random_range(0..=100);
        // The overlap can never exceed the number of input blocks.
        let overlap_blocks = rand::rng().random_range(0..=isl_blocks);

        let event = KVHitRateEvent {
            worker_id,
            isl_blocks,
            // Lossless: the value is at most 100.
            overlap_blocks: overlap_blocks as u32,
        };

        if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await {
            tracing::warn!("Failed to publish KV hit rate event: {e}");
        } else {
            // `isl_blocks` may be 0, in which case the ratio would be 0/0 = NaN.
            // Report a 0% hit rate for an empty input instead.
            let hit_rate = if isl_blocks == 0 {
                0.0
            } else {
                (overlap_blocks as f64 / isl_blocks as f64) * 100.0
            };
            tracing::debug!(
                "Published KV hit rate event: worker_id={worker_id}, isl_blocks={isl_blocks}, overlap_blocks={overlap_blocks}, hit_rate={hit_rate:.2}%"
            );
        }
    }
}

/// Generates mock forward pass metrics for stats handler
99
fn mock_stats_handler(_stats: EndpointStats) -> serde_json::Value {
100
    let request_total_slots = 100;
101
    let request_active_slots = rand::rng().random_range(0..=request_total_slots);
102
    let kv_total_blocks = 100;
103
104
105
106
    let kv_active_blocks = rand::rng().random_range(0..=kv_total_blocks);
    let num_requests_waiting = rand::rng().random_range(0..=100);
    let gpu_cache_usage_perc = rand::rng().random_range(0.0..=1.0);
    let gpu_prefix_cache_hit_rate = rand::rng().random_range(0.0..=1.0);
107
108

    let worker_stats = WorkerStats {
109
        data_parallel_rank: None, // Default for backwards compatibility
110
111
        request_active_slots,
        request_total_slots,
112
113
114
115
        num_requests_waiting,
    };

    let kv_stats = KvStats {
116
117
        kv_active_blocks,
        kv_total_blocks,
118
119
        gpu_cache_usage_perc,
        gpu_prefix_cache_hit_rate,
120
    };
121
122
123
124
125
126
127
128

    let spec_decode_stats = None;

    let stats = ForwardPassMetrics {
        worker_stats,
        kv_stats,
        spec_decode_stats,
    };
129
    tracing::info!("Stats: {stats:?}");
130
131
132
    serde_json::to_value(stats).unwrap()
}

133
async fn backend(runtime: DistributedRuntime) -> Result<()> {
Neelay Shah's avatar
Neelay Shah committed
134
    let namespace = runtime.namespace("dynamo")?;
135
136
    // we must first create a service, then we can attach one more more endpoints
    let component = namespace
137
        .component("MyComponent")?
138
139
140
141
142
        .service_builder()
        .create()
        .await?;
    let endpoint = component.endpoint("my_endpoint");
    tracing::info!("Starting Mock Worker on Endpoint: {}", endpoint.path());
143
144
145
146
147
148
149

    // Spawn background task for publishing KV hit rate events
    let namespace_clone = namespace.clone();
    tokio::spawn(async move {
        mock_event_publisher(namespace_clone).await;
    });

150
    // Attach an ingress to the engine
151
    let ingress = Ingress::for_engine(MockRequestHandler::new())?;
152

153
154
    // Make the ingress discoverable via a component service
    endpoint
155
156
        .endpoint_builder()
        // Dummy stats handler to demonstrate how to attach a custom stats handler
157
        .stats_handler(mock_stats_handler)
158
159
160
161
        .handler(ingress)
        .start()
        .await
}