main.rs 8.21 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Count is a metrics aggregator designed to operate within a namespace and collect
//! metrics from all workers.
//!
//! Metrics collected for now:
//!
//! - LLM Worker Load:Capacity
//!   - These metrics will be scraped by the LLM NATS Service API's stats request
//!   - Request Slots: [Active, Total]
//!   - KV Cache Blocks: [Active, Total]

26
use clap::Parser;
Ryan Olson's avatar
Ryan Olson committed
27
28
use serde::{Deserialize, Serialize};

Neelay Shah's avatar
Neelay Shah committed
29
use triton_distributed_runtime::{
Ryan Olson's avatar
Ryan Olson committed
30
31
32
33
34
35
    error, logging,
    traits::events::EventPublisher,
    utils::{Duration, Instant},
    DistributedRuntime, ErrorContext, Result, Runtime, Worker,
};

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/// CLI arguments for the count application
// Parsed with clap's derive API (`Args::parse()` in `app`); every field is a `--long`
// option. The notes below are plain `//` comments on purpose: clap turns `///` doc
// comments on fields into help text, so those lines are left exactly as they were.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Component to scrape metrics from
    // No default is declared here; `get_config` additionally rejects an empty string.
    #[arg(long)]
    component: String,

    /// Endpoint to scrape metrics from
    // No default is declared here; `get_config` additionally rejects an empty string.
    #[arg(long)]
    endpoint: String,

    /// Namespace to operate in
    // Can also be supplied through the `TRD_NAMESPACE` environment variable;
    // "triton-init" is the declared default value.
    #[arg(long, env = "TRD_NAMESPACE", default_value = "triton-init")]
    namespace: String,

    /// Polling interval in seconds (minimum 1 second)
    // The minimum is not enforced by clap; `get_config` rejects values below 1.
    #[arg(long, default_value = "2")]
    poll_interval: u64,
}
Ryan Olson's avatar
Ryan Olson committed
56

57
58
59
60
fn get_config(args: &Args) -> Result<LLMWorkerLoadCapacityConfig> {
    if args.component.is_empty() {
        return Err(error!("Component name cannot be empty"));
    }
Ryan Olson's avatar
Ryan Olson committed
61

62
63
    if args.endpoint.is_empty() {
        return Err(error!("Endpoint name cannot be empty"));
Ryan Olson's avatar
Ryan Olson committed
64
65
    }

66
67
    if args.poll_interval < 1 {
        return Err(error!("Polling interval must be at least 1 second"));
Ryan Olson's avatar
Ryan Olson committed
68
69
70
    }

    Ok(LLMWorkerLoadCapacityConfig {
71
72
        component_name: args.component.clone(),
        endpoint_name: args.endpoint.clone(),
Ryan Olson's avatar
Ryan Olson committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
    })
}

/// Identifies the single component/endpoint pair whose stats are scraped.
///
/// `app` scrapes the stats of the component named `component_name`, keeps the entries
/// whose subject starts with the subject of its `endpoint_name` endpoint, and publishes
/// the aggregate on the namespace under the event name
/// `l2c.{component_name}.{endpoint_name}`. The config is also stored, as pretty-printed
/// JSON, under the etcd key that marks the unique Count instance.
// NOTE(review): the previous comment described the broadcast subject as
// {namespace}.events.l2c.{service_name}.{endpoint_name}; the `{namespace}.events.` prefix
// is presumably added by the namespace event plane and is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMWorkerLoadCapacityConfig {
    /// Name of the component to scrape (taken from `--component`).
    component_name: String,
    /// Name of the endpoint whose stats are kept (taken from `--endpoint`).
    endpoint_name: String,
}

/// LLM Worker Load Capacity Metrics
///
/// `app` tries to decode this type from the custom data attached to each scraped
/// endpoint's stats. The remaining capacity it publishes for an endpoint is
/// `total - active` for both the KV blocks and the request slots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMWorkerLoadCapacity {
    /// Request slots: active count.
    pub requests_active_slots: u32,
    /// Request slots: total count.
    pub requests_total_slots: u32,
    /// KV cache blocks: active count (the value the published load mean/std are computed over).
    pub kv_blocks_active: u32,
    /// KV cache blocks: total count.
    pub kv_blocks_total: u32,
}

fn main() -> Result<()> {
    logging::init();
    let worker = Worker::from_settings()?;
96
    worker.execute(app)
Ryan Olson's avatar
Ryan Olson committed
97
98
99
}

// TODO - refactor much of this back into the library
100
101
async fn app(runtime: Runtime) -> Result<()> {
    let args = Args::parse();
Ryan Olson's avatar
Ryan Olson committed
102
    // we will start by assuming that there is no oscar and no planner
103
104
105
    // to that end, we will use CLI args to get a singular config for scraping a single backend
    let config = get_config(&args)?;
    tracing::info!("Config: {config:?}");
Ryan Olson's avatar
Ryan Olson committed
106
107
108

    let drt = DistributedRuntime::from_settings(runtime.clone()).await?;

109
    let namespace = drt.namespace(args.namespace)?;
Ryan Olson's avatar
Ryan Olson committed
110
111
112
113
114
    let component = namespace.component("count")?;

    // there should only be one count
    // check {component.etcd_path()}/instance for existing instances
    let key = format!("{}/instance", component.etcd_path());
115
    tracing::info!("Creating unique instance of Count at {key}");
Ryan Olson's avatar
Ryan Olson committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    drt.etcd_client()
        .kv_create(
            key,
            serde_json::to_vec_pretty(&config)?,
            Some(drt.primary_lease().id()),
        )
        .await
        .context("Unable to create unique instance of Count; possibly one already exists")?;

    let target = namespace.component(&config.component_name)?;
    let target_endpoint = target.endpoint(&config.endpoint_name);

    let service_name = target.service_name();
    let service_subject = target_endpoint.subject();
130
    tracing::info!("Scraping service {service_name} and filtering on subject {service_subject}");
Ryan Olson's avatar
Ryan Olson committed
131
132
133
134
135
136
137

    let token = drt.primary_lease().child_token();

    let address = format!("{}.{}", config.component_name, config.endpoint_name,);
    let event_name = format!("l2c.{}", address);

    loop {
138
        let next = Instant::now() + Duration::from_secs(args.poll_interval);
Ryan Olson's avatar
Ryan Olson committed
139
140
141

        // collect stats from each backend
        let stream = target.scrape_stats(Duration::from_secs(1)).await?;
142
        tracing::debug!("Scraped Stats Stream: {stream:?}");
Ryan Olson's avatar
Ryan Olson committed
143
144
145
146
147
148
149

        // filter the stats by the service subject
        let endpoints = stream
            .into_endpoints()
            .filter(|e| e.subject.starts_with(&service_subject))
            .collect::<Vec<_>>();

150
151
152
153
154
        tracing::debug!("Endpoints: {endpoints:?}");
        if endpoints.is_empty() {
            tracing::warn!("No endpoints found matching subject {}", service_subject);
        }

Ryan Olson's avatar
Ryan Olson committed
155
156
157
158
159
160
161
162
        // extract the custom data from the stats and try to decode it as LLMWorkerLoadCapacity
        let metrics = endpoints
            .iter()
            .filter_map(|e| match e.data.clone() {
                Some(metrics) => metrics.decode::<LLMWorkerLoadCapacity>().ok(),
                None => None,
            })
            .collect::<Vec<_>>();
163
        tracing::debug!("Metrics: {metrics:?}");
Ryan Olson's avatar
Ryan Olson committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200

        // parse the endpoint ids
        // the ids are the last part of the subject in hexadecimal
        // form a list of tuples (kv_blocks_total - kv_blocks_active, requests_total_slots - requests_active_slots, id)
        // this tuple represent the remaining capacity of each endpoint
        let capacity_with_ids = metrics
            .iter()
            .zip(endpoints.iter())
            .filter_map(|(m, e)| {
                e.id().ok().map(|id| {
                    (
                        m.kv_blocks_total - m.kv_blocks_active,
                        m.requests_total_slots - m.requests_active_slots,
                        id,
                    )
                })
            })
            .collect::<Vec<_>>();

        // compute mean / std of LLMWorkerLoadCapacity
        let load_values: Vec<f64> = metrics.iter().map(|x| x.kv_blocks_active as f64).collect();
        let load_avg = load_values.iter().sum::<f64>() / load_values.len() as f64;
        let variance = load_values
            .iter()
            .map(|&x| (x - load_avg).powi(2))
            .sum::<f64>()
            / load_values.len() as f64;
        let load_std = variance.sqrt();

        let processed = ProcessedEndpoints {
            capacity_with_ids,
            load_avg,
            load_std,
            address: address.clone(),
        };

        // publish using the namespace event plane
201
202
203
        tracing::info!(
            "Publishing event {event_name} on namespace {namespace:?} with {processed:?}"
        );
Ryan Olson's avatar
Ryan Olson committed
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
        namespace.publish(&event_name, &processed).await?;

        // wait until cancelled or the next tick
        match tokio::time::timeout_at(next, token.cancelled()).await {
            Ok(_) => break,
            Err(_) => {
                // timeout, we continue
                continue;
            }
        }
    }

    Ok(())
}

/// Aggregate built by `app` on every polling cycle and published on the namespace
/// under the event name `l2c.{component}.{endpoint}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessedEndpoints {
    /// (kv_blocks_total - kv_blocks_active, requests_total_slots - requests_active_slots, id)
    ///
    /// One tuple per endpoint, i.e. its remaining KV-block and request-slot capacity
    /// followed by the endpoint id. Endpoints whose id cannot be obtained are omitted.
    pub capacity_with_ids: Vec<(u32, u32, i64)>,

    /// kv_blocks_active average
    ///
    /// Arithmetic mean over the decoded worker metrics.
    pub load_avg: f64,

    /// kv_blocks_active standard deviation
    ///
    /// Population standard deviation (the variance is divided by N, not N - 1).
    pub load_std: f64,

    /// {component}.{endpoint}
    pub address: String,
}
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// `--namespace` falls back to the `TRD_NAMESPACE` environment variable when the
    /// flag is not given on the command line.
    #[test]
    fn test_namespace_from_env() {
        // Environment variables are process-wide and tests share one process, so remember
        // the current value and put it back once parsing is done.
        let previous = env::var_os("TRD_NAMESPACE");
        env::set_var("TRD_NAMESPACE", "test-namespace");

        // Parse args with no explicit namespace
        let args = Args::parse_from(["count", "--component", "comp", "--endpoint", "end"]);

        // Restore before asserting so a failed assertion does not leak the override
        match previous {
            Some(value) => env::set_var("TRD_NAMESPACE", value),
            None => env::remove_var("TRD_NAMESPACE"),
        }

        // Verify namespace was taken from environment variable
        assert_eq!(args.namespace, "test-namespace");
    }
}