main.rs 8.47 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Count is a metrics aggregator designed to operate within a namespace and collect
//! metrics from all workers.
//!
//! Metrics currently collected:
//!
//! - LLM Worker Load:Capacity
//!   - These metrics will be scraped by the LLM NATS Service API's stats request
//!   - Request Slots: [Active, Total]
//!   - KV Cache Blocks: [Active, Total]
25
26
27
28
//! - KV Hit Rate:
//!   - These metrics will be collected from KV hit rate events published by the KV router
//!   - ISL Blocks: Cumulative count of total blocks in all KV hit rate events
//!   - Overlap Blocks: Cumulative count of blocks that were already in the KV cache
29
use clap::Parser;
30
use dynemo_llm::kv_router::scheduler::KVHitRateEvent;
31
use dynemo_runtime::{
Ryan Olson's avatar
Ryan Olson committed
32
    error, logging,
33
    traits::events::{EventPublisher, EventSubscriber},
Ryan Olson's avatar
Ryan Olson committed
34
35
36
    utils::{Duration, Instant},
    DistributedRuntime, ErrorContext, Result, Runtime, Worker,
};
37
38
use futures::stream::StreamExt;
use std::sync::Arc;
Ryan Olson's avatar
Ryan Olson committed
39

40
41
42
43
44
45
// Import from our library
use count::{
    collect_endpoints, extract_metrics, postprocess_metrics, LLMWorkerLoadCapacityConfig,
    PrometheusMetricsServer,
};

46
47
48
49
50
51
52
53
54
55
56
57
58
/// CLI arguments for the count application
// NOTE: the `///` comments on the fields below double as clap help text, so
// editing them changes what `--help` prints.
#[derive(Debug, Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Component to scrape metrics from
    #[arg(long)]
    component: String,

    /// Endpoint to scrape metrics from
    #[arg(long)]
    endpoint: String,

    /// Namespace to operate in
    // May also be supplied through the `DYN_NAMESPACE` environment variable.
    #[arg(long, default_value = "dynemo", env = "DYN_NAMESPACE")]
    namespace: String,

    /// Polling interval in seconds (minimum 1 second)
    // The lower bound is not enforced by clap; `get_config` rejects 0.
    #[arg(long, default_value = "2")]
    poll_interval: u64,
}
Ryan Olson's avatar
Ryan Olson committed
66

67
68
69
70
fn get_config(args: &Args) -> Result<LLMWorkerLoadCapacityConfig> {
    if args.component.is_empty() {
        return Err(error!("Component name cannot be empty"));
    }
Ryan Olson's avatar
Ryan Olson committed
71

72
73
    if args.endpoint.is_empty() {
        return Err(error!("Endpoint name cannot be empty"));
Ryan Olson's avatar
Ryan Olson committed
74
75
    }

76
77
    if args.poll_interval < 1 {
        return Err(error!("Polling interval must be at least 1 second"));
Ryan Olson's avatar
Ryan Olson committed
78
79
80
    }

    Ok(LLMWorkerLoadCapacityConfig {
81
82
        component_name: args.component.clone(),
        endpoint_name: args.endpoint.clone(),
Ryan Olson's avatar
Ryan Olson committed
83
84
85
    })
}

86
87
async fn app(runtime: Runtime) -> Result<()> {
    let args = Args::parse();
88
89
    let config = get_config(&args)?;
    tracing::info!("Config: {config:?}");
Ryan Olson's avatar
Ryan Olson committed
90
91
92

    let drt = DistributedRuntime::from_settings(runtime.clone()).await?;

93
    let namespace = drt.namespace(args.namespace)?;
Ryan Olson's avatar
Ryan Olson committed
94
95
    let component = namespace.component("count")?;

96
    // Create unique instance of Count
Ryan Olson's avatar
Ryan Olson committed
97
    let key = format!("{}/instance", component.etcd_path());
98
    tracing::info!("Creating unique instance of Count at {key}");
Ryan Olson's avatar
Ryan Olson committed
99
100
101
102
103
104
105
106
107
    drt.etcd_client()
        .kv_create(
            key,
            serde_json::to_vec_pretty(&config)?,
            Some(drt.primary_lease().id()),
        )
        .await
        .context("Unable to create unique instance of Count; possibly one already exists")?;

108
109
    let target_component = namespace.component(&config.component_name)?;
    let target_endpoint = target_component.endpoint(&config.endpoint_name);
Ryan Olson's avatar
Ryan Olson committed
110

111
    let service_name = target_component.service_name();
Ryan Olson's avatar
Ryan Olson committed
112
    let service_subject = target_endpoint.subject();
113
    tracing::info!("Scraping service {service_name} and filtering on subject {service_subject}");
Ryan Olson's avatar
Ryan Olson committed
114
115

    let token = drt.primary_lease().child_token();
116
    let event_name = format!("l2c.{}.{}", config.component_name, config.endpoint_name);
Ryan Olson's avatar
Ryan Olson committed
117

118
119
    // TODO: Make metrics host/port configurable
    // Initialize Prometheus metrics and start server
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
    let metrics_server = PrometheusMetricsServer::new()?;
    // Metrics will be updated concurrently, so protect it with a mutex:
    // - Main loop: Collect and process ForwardPassMetrics at an interval from endpoint stats handlers
    // - Subscription task: Collect and process KVHitRateEvent metrics from the KV router as they are published
    let metrics_server = Arc::new(tokio::sync::Mutex::new(metrics_server));
    metrics_server.lock().await.start(9091);

    // Subscribe to KV hit rate events
    let kv_hit_rate_subject = "kv-hit-rate";
    tracing::info!("Subscribing to KV hit rate events on subject: {kv_hit_rate_subject}");

    // Clone the metrics server and config for the subscription task
    let metrics_server_clone = metrics_server.clone();
    let config_clone = config.clone();
    // Clone the namespace for the subscription task
    let namespace_clone = namespace.clone();

    // Spawn a task to handle KV hit rate events
    tokio::spawn(async move {
        match namespace_clone.subscribe(kv_hit_rate_subject).await {
            Ok(mut subscriber) => {
                tracing::info!("Successfully subscribed to KV hit rate events");

                while let Some(msg) = subscriber.next().await {
                    match serde_json::from_slice::<KVHitRateEvent>(&msg.payload) {
                        Ok(event) => {
                            // TODO: Lower to debug
                            let cache_hit_pct =
                                (event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0;
                            tracing::info!(
                                "Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
                                event.worker_id,
                                event.isl_blocks,
                                event.overlap_blocks,
                                cache_hit_pct
                            );

                            // Update metrics with the event data
                            let mut metrics = metrics_server_clone.lock().await;
                            metrics.update_kv_hit_rate(
                                &config_clone,
                                event.worker_id,
                                event.isl_blocks,
                                event.overlap_blocks,
                            );
                        }
                        Err(e) => {
                            tracing::warn!("Failed to deserialize KV hit rate event: {:?}", e);
                        }
                    }
                }

                tracing::warn!("KV hit rate event subscription stream ended");
            }
            Err(e) => {
                tracing::error!("Failed to subscribe to KV hit rate events: {:?}", e);
            }
        }
    });
Ryan Olson's avatar
Ryan Olson committed
179
180

    loop {
181
        let next = Instant::now() + Duration::from_secs(args.poll_interval);
Ryan Olson's avatar
Ryan Olson committed
182

183
184
185
186
187
188
        // Collect and process metrics
        let scrape_timeout = Duration::from_secs(1);
        let endpoints =
            collect_endpoints(&target_component, &service_subject, scrape_timeout).await?;
        let metrics = extract_metrics(&endpoints);
        let processed = postprocess_metrics(&metrics, &endpoints);
189
        tracing::debug!("Aggregated metrics: {processed:?}");
Ryan Olson's avatar
Ryan Olson committed
190

191
        // Update Prometheus metrics
192
        metrics_server.lock().await.update(&config, &processed);
193

194
195
196
        // TODO: Enable KV Routers to subscribe to metrics events published here
        // for a single view of the aggregated metrics, as opposed to the current
        // approach where each KV Router computes and published its own metrics.
197
        // Publish metrics event
Ryan Olson's avatar
Ryan Olson committed
198
199
        namespace.publish(&event_name, &processed).await?;

200
        // Wait until cancelled or the next tick
Ryan Olson's avatar
Ryan Olson committed
201
202
        match tokio::time::timeout_at(next, token.cancelled()).await {
            Ok(_) => break,
203
            Err(_) => continue,
Ryan Olson's avatar
Ryan Olson committed
204
205
206
207
208
209
        }
    }

    Ok(())
}

210
211
212
213
fn main() -> Result<()> {
    logging::init();
    let worker = Worker::from_settings()?;
    worker.execute(app)
Ryan Olson's avatar
Ryan Olson committed
214
}
215
216
217
218
219
220
221
222

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// `namespace` is taken from `DYN_NAMESPACE` when `--namespace` is not passed.
    #[test]
    fn test_namespace_from_env() {
        const VAR: &str = "DYN_NAMESPACE";

        // Environment variables are process-wide; remember the prior value so this
        // test does not leak its override into other tests in the same process.
        let previous = env::var_os(VAR);
        env::set_var(VAR, "test-namespace");

        let args = Args::parse_from(["count", "--component", "comp", "--endpoint", "end"]);

        // Restore before asserting so a failed assertion cannot skip the cleanup.
        match previous {
            Some(value) => env::set_var(VAR, value),
            None => env::remove_var(VAR),
        }

        assert_eq!(args.namespace, "test-namespace");
    }
}