mock_worker.rs 5.69 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use async_nats::service::endpoint::Stats;
Neelay Shah's avatar
Neelay Shah committed
17
use dynamo_llm::kv_router::{
18
19
    protocols::ForwardPassMetrics, scheduler::KVHitRateEvent, KV_HIT_RATE_SUBJECT,
};
Neelay Shah's avatar
Neelay Shah committed
20
use dynamo_runtime::{
21
    component::Namespace,
22
23
24
25
26
27
    logging,
    pipeline::{
        async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
        ResponseStream, SingleIn,
    },
    protocols::annotated::Annotated,
28
29
30
    stream,
    traits::events::EventPublisher,
    DistributedRuntime, Result, Runtime, Worker,
31
};
32
33
use rand::Rng;
use std::sync::Arc;
34
use tokio::time::{interval, Duration};
35
36
37
38
39
40
41
42
43
44
45
46

/// Process entry point.
///
/// Initializes logging, constructs a [`Worker`] via `Worker::from_settings`, and asks it to
/// execute the async [`app`] function. Any error from construction or execution is returned.
fn main() -> Result<()> {
    logging::init();
    Worker::from_settings()?.execute(app)
}

/// Async application body handed to the worker by `main`.
///
/// Builds a [`DistributedRuntime`] with `from_settings` from a clone of `runtime`, then runs
/// [`backend`] on it. An error from either step is propagated to the caller.
async fn app(runtime: Runtime) -> Result<()> {
    backend(DistributedRuntime::from_settings(runtime.clone()).await?).await
}

47
/// Stateless request handler served by the mock worker.
///
/// It carries no fields; its request-handling behavior lives in the `AsyncEngine`
/// implementation for this type.
struct MockRequestHandler {}

impl MockRequestHandler {
    /// Creates a handler already wrapped in an [`Arc`], the form in which it is passed
    /// to `Ingress::for_engine`.
    fn new() -> Arc<Self> {
        Arc::new(Self {})
    }
}

#[async_trait]
56
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for MockRequestHandler {
57
58
59
60
61
62
63
64
65
66
67
68
69
70
    async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
        let (data, ctx) = input.into_parts();

        let chars = data
            .chars()
            .map(|c| Annotated::from_data(c.to_string()))
            .collect::<Vec<_>>();

        let stream = stream::iter(chars);

        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}

71
// FIXME: These events are just for testing and may not currently be used.
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/// Periodically publishes mock KV hit rate events on `namespace`.
///
/// This function never returns: on every tick of a one-second interval it builds a
/// [`KVHitRateEvent`] with random block counts and publishes it to [`KV_HIT_RATE_SUBJECT`].
/// A failed publish is logged as a warning and the loop carries on. The function does not
/// spawn anything itself; callers run it as a background task.
async fn mock_event_publisher(namespace: Namespace) {
    // NOTE: These events are just for testing, and shouldn't be interpreted
    // in correlation with the stats handler's data:
    // 1. The worker ID associated with the events here won't match the
    // worker ID of the endpoint's service stats handler.
    // 2. These events aren't coming through the KV Router, so the metrics won't
    // be reflective of the KV Router's performance.
    // 3. The data in these events aren't in sync with the stats handler's
    // ForwardPassMetrics data, so they may not correlate well.
    let worker_id = rand::thread_rng().gen_range(1..=1000);

    let mut ticker = interval(Duration::from_secs(1));
    loop {
        ticker.tick().await;

        // Call `thread_rng()` for each draw instead of keeping the handle alive across an
        // `.await`: `ThreadRng` is not `Send`, and holding it would stop this future from
        // being spawned onto the multi-threaded runtime.
        let isl_blocks = rand::thread_rng().gen_range(0..=100);
        // The overlap can never exceed the number of input blocks.
        let overlap_blocks = rand::thread_rng().gen_range(0..=isl_blocks);

        let event = KVHitRateEvent {
            worker_id,
            isl_blocks,
            overlap_blocks,
        };

        if let Err(e) = namespace.publish(KV_HIT_RATE_SUBJECT, &event).await {
            tracing::warn!("Failed to publish KV hit rate event: {e}");
        } else {
            // `isl_blocks` may be 0, in which case the ratio would be 0/0 = NaN; report a
            // 0% hit rate for an empty request instead.
            let hit_rate = if isl_blocks == 0 {
                0.0
            } else {
                (overlap_blocks as f64 / isl_blocks as f64) * 100.0
            };
            tracing::debug!(
                "Published KV hit rate event: worker_id={worker_id}, isl_blocks={isl_blocks}, overlap_blocks={overlap_blocks}, hit_rate={:.2}%",
                hit_rate
            );
        }
    }
}

/// Generates mock forward pass metrics for the endpoint's stats handler.
///
/// The incoming `_stats` value is ignored. Each call fills a [`ForwardPassMetrics`] with
/// fixed totals (100 slots, 100 blocks) and freshly drawn random values for the remaining
/// fields, logs it at info level, and returns it as a JSON value.
///
/// # Panics
///
/// Panics if the metrics cannot be converted to a `serde_json::Value`.
fn mock_stats_handler(_stats: Stats) -> serde_json::Value {
    // This function is synchronous, so one RNG handle can be reused for every draw.
    let mut rng = rand::thread_rng();

    let request_total_slots = 100;
    let kv_total_blocks = 100;

    let stats = ForwardPassMetrics {
        // Active counts are bounded by their respective totals.
        request_active_slots: rng.gen_range(0..=request_total_slots),
        request_total_slots,
        kv_active_blocks: rng.gen_range(0..=kv_total_blocks),
        kv_total_blocks,
        num_requests_waiting: rng.gen_range(0..=100),
        // Usage and hit rate are fractions in [0, 1], drawn independently of the counts above.
        gpu_cache_usage_perc: rng.gen_range(0.0..=1.0),
        gpu_prefix_cache_hit_rate: rng.gen_range(0.0..=1.0),
    };
    tracing::info!("Stats: {stats:?}");

    serde_json::to_value(stats).expect("mock ForwardPassMetrics should serialize to JSON")
}

131
async fn backend(runtime: DistributedRuntime) -> Result<()> {
Neelay Shah's avatar
Neelay Shah committed
132
    let namespace = runtime.namespace("dynamo")?;
133
134
135
136
137
138
139
140
    // we must first create a service, then we can attach one more more endpoints
    let component = namespace
        .component("my_component")?
        .service_builder()
        .create()
        .await?;
    let endpoint = component.endpoint("my_endpoint");
    tracing::info!("Starting Mock Worker on Endpoint: {}", endpoint.path());
141
142
143
144
145
146
147

    // Spawn background task for publishing KV hit rate events
    let namespace_clone = namespace.clone();
    tokio::spawn(async move {
        mock_event_publisher(namespace_clone).await;
    });

148
    // Attach an ingress to the engine
149
    let ingress = Ingress::for_engine(MockRequestHandler::new())?;
150

151
152
    // Make the ingress discoverable via a component service
    endpoint
153
154
        .endpoint_builder()
        // Dummy stats handler to demonstrate how to attach a custom stats handler
155
        .stats_handler(mock_stats_handler)
156
157
158
159
        .handler(ingress)
        .start()
        .await
}