// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Count is a metrics aggregator designed to operate within a namespace and collect
//! metrics from all workers.
//!
//! For now, the following metrics are collected:
//!
//! - LLM Worker Load:Capacity
//!   - These metrics are scraped via the LLM NATS Service API's stats request
//!   - Request Slots: [Active, Total]
//!   - KV Cache Blocks: [Active, Total]
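//!
//! The scrape target is configured via the `TRD_COUNT_SCRAPE_COMPONENT` and
//! `TRD_COUNT_SCRAPE_ENDPOINT` environment variables (plus `TRD_NAMESPACE`, which
//! defaults to `default`), and the aggregated snapshot is published on the namespace
//! event plane as `{namespace}.events.l2c.{component}.{endpoint}`.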

use serde::{Deserialize, Serialize};

use triton_distributed::{
    error, logging,
    traits::events::EventPublisher,
    utils::{Duration, Instant},
    DistributedRuntime, ErrorContext, Result, Runtime, Worker,
};

use tracing as log;

// enum MetricTypes {
//     LLMWorkerLoadCapacity(LLMWorkerLoadCapacityConfig),
// }

/// Build the scrape target configuration from the `TRD_COUNT_SCRAPE_COMPONENT` and
/// `TRD_COUNT_SCRAPE_ENDPOINT` environment variables; both must be set and non-empty.
fn get_config() -> Result<LLMWorkerLoadCapacityConfig> {
    let component_name = std::env::var("TRD_COUNT_SCRAPE_COMPONENT")?;
    if component_name.is_empty() {
        return Err(error!("TRD_COUNT_SCRAPE_COMPONENT must not be empty"));
    }

    let endpoint_name = std::env::var("TRD_COUNT_SCRAPE_ENDPOINT")?;
    if endpoint_name.is_empty() {
        return Err(error!("TRD_COUNT_SCRAPE_ENDPOINT must not be empty"));
    }

    Ok(LLMWorkerLoadCapacityConfig {
        component_name,
        endpoint_name,
    })
}

/// Scrape configuration: we scrape the target component's service stats, extract the
/// named endpoint's metrics, and broadcast them as
/// `{namespace}.events.l2c.{service_name}.{endpoint_name}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMWorkerLoadCapacityConfig {
    component_name: String,
    endpoint_name: String,
}

/// LLM Worker Load Capacity Metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMWorkerLoadCapacity {
    pub requests_active_slots: u32,
    pub requests_total_slots: u32,
    pub kv_blocks_active: u32,
    pub kv_blocks_total: u32,
}

fn main() -> Result<()> {
    logging::init();
    let worker = Worker::from_settings()?;
    worker.execute(app)
}

// TODO - refactor much of this back into the library
async fn app(runtime: Runtime) -> Result<()> {
    // start by assuming there is no oscar and no planner;
    // instead, use environment variables to build a single config for scraping a single backend
    let config = get_config()?;

    let drt = DistributedRuntime::from_settings(runtime.clone()).await?;

    // TODO: move this into DistributedRuntime and standardize it via file/env/CLI config
    let namespace = std::env::var("TRD_NAMESPACE").unwrap_or_else(|_| "default".to_string());

    let namespace = drt.namespace(namespace)?;
    let component = namespace.component("count")?;

    // there should only be one count
    // check {component.etcd_path()}/instance for existing instances
    let key = format!("{}/instance", component.etcd_path());
    drt.etcd_client()
        .kv_create(
            key,
            serde_json::to_vec_pretty(&config)?,
            Some(drt.primary_lease().id()),
        )
        .await
        .context("Unable to create unique instance of Count; possibly one already exists")?;

    let target = namespace.component(&config.component_name)?;
    let target_endpoint = target.endpoint(&config.endpoint_name);

    let service_name = target.service_name();
    let service_subject = target_endpoint.subject();

    log::debug!("Scraping service {service_name} and filtering on subject {service_subject}");
    let token = drt.primary_lease().child_token();

    // events are published on the namespace event plane as "l2c.{component}.{endpoint}"
    let address = format!("{}.{}", config.component_name, config.endpoint_name);
    let event_name = format!("l2c.{}", address);

    loop {
        // TODO - make this configurable
        let next = Instant::now() + Duration::from_secs(2);
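        // the 2s tick bounds how often we publish; the 1s scrape timeout below bounds
        // how long we wait for worker stats within each tick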

        // collect stats from each backend
        let stream = target.scrape_stats(Duration::from_secs(1)).await?;

        // filter the stats by the service subject
        let endpoints = stream
            .into_endpoints()
            .filter(|e| e.subject.starts_with(&service_subject))
            .collect::<Vec<_>>();

        // extract the custom data from the stats and try to decode it as LLMWorkerLoadCapacity
        let metrics = endpoints
            .iter()
            .filter_map(|e| match e.data.clone() {
                Some(metrics) => metrics.decode::<LLMWorkerLoadCapacity>().ok(),
                None => None,
            })
            .collect::<Vec<_>>();

        // parse the endpoint ids
        // the ids are the last part of the subject in hexadecimal
        // form a list of tuples (kv_blocks_total - kv_blocks_active, requests_total_slots - requests_active_slots, id)
        // this tuple represents the remaining capacity of each endpoint
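        // e.g. a worker with 100 total / 40 active kv blocks and 8 total / 3 active
        // request slots yields (60, 5, id)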
        let capacity_with_ids = metrics
            .iter()
            .zip(endpoints.iter())
            .filter_map(|(m, e)| {
                e.id().ok().map(|id| {
                    (
                        m.kv_blocks_total - m.kv_blocks_active,
                        m.requests_total_slots - m.requests_active_slots,
                        id,
                    )
                })
            })
            .collect::<Vec<_>>();

        // compute the mean / standard deviation of active KV blocks across endpoints;
        // guard against an empty metrics set so we never divide by zero and publish NaN
        let load_values: Vec<f64> = metrics.iter().map(|x| x.kv_blocks_active as f64).collect();
        let (load_avg, load_std) = if load_values.is_empty() {
            (0.0, 0.0)
        } else {
            let load_avg = load_values.iter().sum::<f64>() / load_values.len() as f64;
            let variance = load_values
                .iter()
                .map(|&x| (x - load_avg).powi(2))
                .sum::<f64>()
                / load_values.len() as f64;
            (load_avg, variance.sqrt())
        };

        let processed = ProcessedEndpoints {
            capacity_with_ids,
            load_avg,
            load_std,
            address: address.clone(),
        };

        // publish using the namespace event plane
        namespace.publish(&event_name, &processed).await?;

        // sleep until the next tick, but wake immediately if the cancellation token fires
        match tokio::time::timeout_at(next, token.cancelled()).await {
            // cancelled; exit the scrape loop
            Ok(_) => break,
            // timed out waiting for cancellation; run the next scrape iteration
            Err(_) => continue,
        }
    }

    Ok(())
}

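/// Aggregated load/capacity snapshot for all matching endpoints, published on the
/// namespace event plane.
///
/// Illustrative payload (field values are hypothetical; serde serializes the tuples as arrays):
///
/// ```json
/// {
///   "capacity_with_ids": [[60, 5, 1234567890]],
///   "load_avg": 40.0,
///   "load_std": 0.0,
///   "address": "backend.generate"
/// }
/// ```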
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessedEndpoints {
    /// (kv_blocks_total - kv_blocks_active, requests_total_slots - requests_active_slots, id)
    pub capacity_with_ids: Vec<(u32, u32, i64)>,

    /// kv_blocks_active average
    pub load_avg: f64,

    /// kv_blocks_active standard deviation
    pub load_std: f64,

    /// {component}.{endpoint}
    pub address: String,
}