Commit 494d5625 authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

fix: Fix stream::until_deadline bug and improve metric examples (#280)


Co-authored-by: default avatarRyan Olson <rolson@nvidia.com>
parent cec8248d
......@@ -41,6 +41,56 @@ dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.6.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9"
[[package]]
name = "anstyle-parse"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e"
dependencies = [
"anstyle",
"once_cell",
"windows-sys 0.59.0",
]
[[package]]
name = "anyhow"
version = "1.0.96"
......@@ -441,6 +491,52 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "clap"
version = "4.5.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.5.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.5.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "clap_lex"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "colorchoice"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "console"
version = "0.15.10"
......@@ -496,8 +592,10 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
name = "count"
version = "0.1.0"
dependencies = [
"clap",
"serde",
"serde_json",
"thiserror 1.0.69",
"tokio",
"tracing",
"triton-distributed-llm",
......@@ -1507,6 +1605,12 @@ dependencies = [
"libc",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
......@@ -3723,6 +3827,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.14.0"
......
......@@ -27,7 +27,9 @@ triton-distributed-llm = { path = "../../../lib/llm" }
# workspace - todo
# crates.io
clap = { version = "4.5", features = ["derive", "env"] }
serde = { version = "1", features = ["derive"] }
serde_json = { version = "1" }
tokio = { version = "1", features = ["full"] }
tracing = { version = "0.1" }
thiserror = "1.0"
# Count
## Quickstart
To start `count`, simply point it at the namespace/component/endpoint trio that
you're interested in observing metrics from. This will scrape statistics from
the services associated with that endpoint, do some postprocessing on them,
and then publish an event with the postprocessed data.
```bash
# For more details, try TRD_LOG=debug
TRD_LOG=info cargo run -- --namespace triton-init --component backend --endpoint generate
# 2025-02-26T18:45:05.467026Z INFO count: Creating unique instance of Count at triton-init/components/count/instance
# 2025-02-26T18:45:05.472146Z INFO count: Scraping service triton_init_backend_720278f8 and filtering on subject triton_init_backend_720278f8.generate
# ...
```
With no matching endpoints running, you should see warnings in the logs:
```bash
2025-02-26T18:45:06.474161Z WARN count: No endpoints found matching subject triton_init_backend_720278f8.generate
```
But after starting a matching endpoint, such as the
[service_metrics example](examples/rust/service_metrics/src/bin/server.rs),
you should see these warnings go away since the endpoint will automatically
get discovered.
Whether there are matching endpoints found or not, `count` will publish events, for example:
```
2025-02-26T18:45:46.501874Z INFO count: Publishing event l2c.backend.generate on Namespace { name: "triton-init" } with ProcessedEndpoints { capacity_with_ids: [], load_avg: NaN, load_std: NaN, address: "backend.generate" }
```
However, the events may not be very useful until there are corresponding stats found from endpoints for processing.
......@@ -23,6 +23,7 @@
//! - Request Slots: [Active, Total]
//! - KV Cache Blocks: [Active, Total]
use clap::Parser;
use serde::{Deserialize, Serialize};
use triton_distributed_runtime::{
......@@ -32,26 +33,43 @@ use triton_distributed_runtime::{
DistributedRuntime, ErrorContext, Result, Runtime, Worker,
};
use tracing as log;
/// CLI arguments for the count application
// NOTE: with clap's derive API, the `///` doc comments on this struct and on
// each field are emitted verbatim as the `--help` text at runtime, so they are
// user-facing strings — edit them only when the help output should change.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Component to scrape metrics from
    // Required flag: `--component <NAME>` (no default).
    #[arg(long)]
    component: String,

    /// Endpoint to scrape metrics from
    // Required flag: `--endpoint <NAME>` (no default).
    #[arg(long)]
    endpoint: String,

    /// Namespace to operate in
    // Resolution order: `--namespace` flag, then the TRD_NAMESPACE environment
    // variable, then the literal default "triton-init".
    #[arg(long, env = "TRD_NAMESPACE", default_value = "triton-init")]
    namespace: String,

    /// Polling interval in seconds (minimum 1 second)
    // The minimum is not enforced here by clap; the app validates it separately
    // (values < 1 are rejected with an error) — TODO confirm validation site.
    #[arg(long, default_value = "2")]
    poll_interval: u64,
}
// enum MetricTypes {
// LLMWorkerLoadCapacity(LLMWorkerLoadCapacityConfig),
// }
fn get_config(args: &Args) -> Result<LLMWorkerLoadCapacityConfig> {
if args.component.is_empty() {
return Err(error!("Component name cannot be empty"));
}
fn get_config() -> Result<LLMWorkerLoadCapacityConfig> {
let component_name = std::env::var("TRD_COUNT_SCRAPE_COMPONENT")?;
if component_name.is_empty() {
return Err(error!("TRD_COUNT_SCRAPE_COMPONENT is not set"));
if args.endpoint.is_empty() {
return Err(error!("Endpoint name cannot be empty"));
}
let endpoint_name = std::env::var("TRD_COUNT_SCRAPE_ENDPOINT")?;
if endpoint_name.is_empty() {
return Err(error!("TRD_COUNT_SCRAPE_ENDPOINT is not set"));
if args.poll_interval < 1 {
return Err(error!("Polling interval must be at least 1 second"));
}
Ok(LLMWorkerLoadCapacityConfig {
component_name,
endpoint_name,
component_name: args.component.clone(),
endpoint_name: args.endpoint.clone(),
})
}
......@@ -74,27 +92,27 @@ pub struct LLMWorkerLoadCapacity {
fn main() -> Result<()> {
logging::init();
let args = Args::parse();
let worker = Worker::from_settings()?;
worker.execute(app)
worker.execute(|runtime| app(runtime, args))
}
// TODO - refactor much of this back into the library
async fn app(runtime: Runtime) -> Result<()> {
async fn app(runtime: Runtime, args: Args) -> Result<()> {
// we will start by assuming that there is no oscar and no planner
// to that end, we will use an env to get a singular config for scraping a single backend
let config = get_config()?;
// to that end, we will use CLI args to get a singular config for scraping a single backend
let config = get_config(&args)?;
tracing::info!("Config: {config:?}");
let drt = DistributedRuntime::from_settings(runtime.clone()).await?;
// todo move to distributed and standardize and move into file/env/cli config
let namespace = std::env::var("TRD_NAMESPACE").unwrap_or("default".to_string());
let namespace = drt.namespace(namespace)?;
let namespace = drt.namespace(args.namespace)?;
let component = namespace.component("count")?;
// there should only be one count
// check {component.etcd_path()}/instance for existing instances
let key = format!("{}/instance", component.etcd_path());
tracing::info!("Creating unique instance of Count at {key}");
drt.etcd_client()
.kv_create(
key,
......@@ -109,19 +127,19 @@ async fn app(runtime: Runtime) -> Result<()> {
let service_name = target.service_name();
let service_subject = target_endpoint.subject();
tracing::info!("Scraping service {service_name} and filtering on subject {service_subject}");
log::debug!("Scraping service {service_name} and filtering on subject {service_subject}");
let token = drt.primary_lease().child_token();
let address = format!("{}.{}", config.component_name, config.endpoint_name,);
let event_name = format!("l2c.{}", address);
loop {
// TODO - make this configurable
let next = Instant::now() + Duration::from_secs(2);
let next = Instant::now() + Duration::from_secs(args.poll_interval);
// collect stats from each backend
let stream = target.scrape_stats(Duration::from_secs(1)).await?;
tracing::debug!("Scraped Stats Stream: {stream:?}");
// filter the stats by the service subject
let endpoints = stream
......@@ -129,6 +147,11 @@ async fn app(runtime: Runtime) -> Result<()> {
.filter(|e| e.subject.starts_with(&service_subject))
.collect::<Vec<_>>();
tracing::debug!("Endpoints: {endpoints:?}");
if endpoints.is_empty() {
tracing::warn!("No endpoints found matching subject {}", service_subject);
}
// extract the custom data from the stats and try to decode it as LLMWorkerLoadCapacity
let metrics = endpoints
.iter()
......@@ -137,6 +160,7 @@ async fn app(runtime: Runtime) -> Result<()> {
None => None,
})
.collect::<Vec<_>>();
tracing::debug!("Metrics: {metrics:?}");
// parse the endpoint ids
// the ids are the last part of the subject in hexadecimal
......@@ -174,6 +198,9 @@ async fn app(runtime: Runtime) -> Result<()> {
};
// publish using the namespace event plane
tracing::info!(
"Publishing event {event_name} on namespace {namespace:?} with {processed:?}"
);
namespace.publish(&event_name, &processed).await?;
// wait until cancelled or the next tick
......@@ -203,3 +230,20 @@ pub struct ProcessedEndpoints {
/// {component}.{endpoint}
pub address: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Verifies that `Args::namespace` falls back to the TRD_NAMESPACE
    /// environment variable when no `--namespace` flag is given.
    #[test]
    fn test_namespace_from_env() {
        // Save any pre-existing value: env vars are process-global and Rust
        // runs tests in parallel by default, so leaking this mutation could
        // break unrelated tests. Restore it before returning.
        let previous = env::var("TRD_NAMESPACE").ok();
        env::set_var("TRD_NAMESPACE", "test-namespace");

        // Parse args with no explicit --namespace flag.
        let args = Args::parse_from(["count", "--component", "comp", "--endpoint", "end"]);

        // Verify namespace was taken from the environment variable.
        assert_eq!(args.namespace, "test-namespace");

        // Restore the original process environment.
        match previous {
            Some(value) => env::set_var("TRD_NAMESPACE", value),
            None => env::remove_var("TRD_NAMESPACE"),
        }
    }
}
......@@ -2993,6 +2993,8 @@ name = "service_metrics"
version = "0.2.0"
dependencies = [
"futures",
"serde",
"serde_json",
"tokio",
"triton-distributed-runtime",
]
......
......@@ -27,4 +27,6 @@ triton-distributed-runtime = { workspace = true }
# third-party
futures = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
......@@ -3,10 +3,16 @@
This example extends the hello_world example by calling the `scrape_service` method
with the name of the service to which the client just issued a request.
```bash
TRD_LOG=debug cargo run --bin server
```
The client can now observe some basic statistics about each instance of the service
being hosted.
If you start two copies of the server, you will see two entries being emitted.
```bash
TRD_LOG=info cargo run --bin client
```
## Example Output
```
......@@ -21,5 +27,13 @@ Annotated { data: Some("o"), id: None, event: None, comment: None }
Annotated { data: Some("r"), id: None, event: None, comment: None }
Annotated { data: Some("l"), id: None, event: None, comment: None }
Annotated { data: Some("d"), id: None, event: None, comment: None }
ServiceSet { services: [ServiceInfo { name: "triton_init_backend_720278f8", id: "j6n37goJog3df2PMkQK1Ry", version: "0.0.1", started: "2025-02-18T20:51:01.40830026Z", endpoints: [EndpointInfo { name: "triton_init_backend_720278f8-generate-694d94fc30dbb562", subject: "triton_init_backend_720278f8.generate-694d94fc30dbb562", data: Some(Metrics(Object {"average_processing_time": Number(67387), "last_error": String(""), "num_errors": Number(0), "num_requests": Number(1), "processing_time": Number(67387), "queue_group": String("q")})) }] }] }
ServiceSet { services: [ServiceInfo { name: "triton_init_backend_720278f8", id: "eOHMc4ndRw8s5flv4WOZx7", version: "0.0.1", started: "2025-02-26T18:54:04.917294605Z", endpoints: [EndpointInfo { name: "triton_init_backend_720278f8-generate-694d951a80e06abf", subject: "triton_init_backend_720278f8.generate-694d951a80e06abf", data: Some(Metrics(Object {"average_processing_time": Number(53662), "data": Object {"val": Number(10)}, "last_error": String(""), "num_errors": Number(0), "num_requests": Number(2), "processing_time": Number(107325), "queue_group": String("q")})) }] }] }
```
Note the following stats in the output demonstrate the custom
`stats_handler` attached to the service in `server.rs` is being invoked:
```
data: Some(Metrics(Object {..., "data": Object {"val": Number(10)}, ...)
```
If you start two copies of the server, you will see two entries being emitted.
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use service_metrics::DEFAULT_NAMESPACE;
use service_metrics::{MyStats, DEFAULT_NAMESPACE};
use std::sync::Arc;
use triton_distributed_runtime::{
......@@ -71,6 +71,11 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
.namespace(DEFAULT_NAMESPACE)?
.component("backend")?
.service_builder()
// Dummy stats handler to demonstrate how to attach a custom stats handler
.stats_handler(Some(Box::new(|_name, _stats| {
let stats = MyStats { val: 10 };
serde_json::to_value(stats).unwrap()
})))
.create()
.await?
.endpoint("generate")
......
......@@ -13,4 +13,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
pub const DEFAULT_NAMESPACE: &str = "triton-init";
/// Dummy stats object to demonstrate how to attach a custom stats handler.
// Serialized to JSON by the `stats_handler` closure in server.rs and surfaced
// under the "data" key of the scraped service metrics.
#[derive(Serialize, Deserialize)]
pub struct MyStats {
    // Arbitrary example value; the server's handler hard-codes `val: 10`.
    pub val: u32,
}
......@@ -110,7 +110,7 @@ impl ServiceClient {
}
let deadline = tokio::time::Instant::now() + duration;
let services = stream::until_deadline(sub, deadline)
let services: Vec<ServiceInfo> = stream::until_deadline(sub, deadline)
.map(|message| serde_json::from_slice::<ServiceInfo>(&message.payload))
.filter_map(|info| async move {
match info {
......
......@@ -24,7 +24,7 @@ use tokio::time::{self, sleep_until, Duration, Instant, Sleep};
pub struct DeadlineStream<S> {
stream: S,
deadline: Instant,
sleep: Pin<Box<Sleep>>,
}
impl<S: Stream + Unpin> Stream for DeadlineStream<S> {
......@@ -32,7 +32,7 @@ impl<S: Stream + Unpin> Stream for DeadlineStream<S> {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// Check if we've passed the deadline
if Instant::now() >= self.deadline {
if Pin::new(&mut self.sleep).poll(cx).is_ready() {
// The deadline expired; end the stream now
return Poll::Ready(None);
}
......@@ -43,7 +43,11 @@ impl<S: Stream + Unpin> Stream for DeadlineStream<S> {
}
pub fn until_deadline<S: Stream + Unpin>(stream: S, deadline: Instant) -> DeadlineStream<S> {
DeadlineStream { stream, deadline }
DeadlineStream {
stream,
// Set an async task that sleeps until deadline and wakes up to cancel the stream
sleep: Box::pin(sleep_until(deadline)),
}
}
#[cfg(test)]
......@@ -53,9 +57,9 @@ mod tests {
use super::*;
#[tokio::test]
async fn test_until_deadline() {
let stream = stream::iter(vec![100, 100, 200]);
// Helper function to run the deadline stream test with given parameters
async fn run_deadline_test(sleep_times_ms: Vec<u64>, deadline_ms: u64) -> Vec<u64> {
let stream = stream::iter(sleep_times_ms);
let stream = stream.then(|x| {
let sleep = time::sleep(Duration::from_millis(x));
async move {
......@@ -63,13 +67,39 @@ mod tests {
x
}
});
let deadline = Instant::now() + Duration::from_millis(300);
let deadline = Instant::now() + Duration::from_millis(deadline_ms);
let mut result = Vec::new();
pin!(stream);
let mut stream = until_deadline(stream, deadline);
while let Some(x) = stream.next().await {
result.push(x);
}
result
}
#[tokio::test]
async fn test_deadline_exceeded() {
// The sum of the sleep times should exceed the deadline
let sleep_times_ms = vec![100, 100, 200, 50];
let deadline_ms = 300;
let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
// Since deadline is exceeded, only the items before deadline should be returned
assert_eq!(result, vec![100, 100]);
}
#[tokio::test]
async fn test_complete_before_deadline() {
// The sum of the sleep times should be less than the deadline
let sleep_times_ms = vec![100, 50, 50];
let deadline_ms = 300;
let result = run_deadline_test(sleep_times_ms, deadline_ms).await;
// Since deadline is not exceeded, all items should be returned from stream
assert_eq!(result, vec![100, 50, 50]);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment