Commit d0d35a9e authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: http + llmctl (#181)


Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 4e6f3fef
......@@ -41,6 +41,56 @@ dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "0.6.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9"
[[package]]
name = "anstyle-parse"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e"
dependencies = [
"anstyle",
"once_cell",
"windows-sys 0.59.0",
]
[[package]]
name = "anyhow"
version = "1.0.95"
......@@ -84,7 +134,7 @@ dependencies = [
"serde_json",
"serde_nanos",
"serde_repr",
"thiserror",
"thiserror 1.0.69",
"time",
"tokio",
"tokio-rustls",
......@@ -144,7 +194,7 @@ dependencies = [
"mio 0.6.23",
"once_cell",
"slab",
"thiserror",
"thiserror 1.0.69",
"zmq",
]
......@@ -176,24 +226,58 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core",
"axum-core 0.4.5",
"bytes",
"futures-util",
"http 1.2.0",
"http-body",
"http-body-util",
"itoa",
"matchit 0.7.3",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper",
"tower 0.5.2",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8"
dependencies = [
"axum-core 0.5.0",
"bytes",
"form_urlencoded",
"futures-util",
"http",
"http 1.2.0",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"itoa",
"matchit",
"matchit 0.8.4",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower 0.5.2",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
......@@ -205,7 +289,26 @@ dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http 1.2.0",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733"
dependencies = [
"bytes",
"futures-util",
"http 1.2.0",
"http-body",
"http-body-util",
"mime",
......@@ -214,6 +317,7 @@ dependencies = [
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
......@@ -283,6 +387,12 @@ version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
[[package]]
name = "bytecount"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce"
[[package]]
name = "bytemuck"
version = "1.21.0"
......@@ -357,6 +467,52 @@ dependencies = [
"windows-targets",
]
[[package]]
name = "clap"
version = "4.5.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acebd8ad879283633b343856142139f2da2317c96b05b4dd6181c61e2480184"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.5.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ba32cbda51c7e1dfd49acc1457ba1a7dec5b64fe360e828acb13ca8dc9c2f9"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.5.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "clap_lex"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "colorchoice"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]]
name = "const-oid"
version = "0.9.6"
......@@ -717,7 +873,7 @@ version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0452bcc559431b16f472b7ab86e2f9ccd5f3c2da3795afbd6b773665e047fe"
dependencies = [
"http",
"http 1.2.0",
"prost",
"tokio",
"tokio-stream",
......@@ -932,7 +1088,7 @@ dependencies = [
"fnv",
"futures-core",
"futures-sink",
"http",
"http 1.2.0",
"indexmap 2.7.1",
"slab",
"tokio",
......@@ -965,6 +1121,17 @@ dependencies = [
"triton-distributed",
]
[[package]]
name = "http"
version = "0.2.0"
dependencies = [
"serde",
"serde_json",
"tokio",
"triton-distributed",
"triton-llm",
]
[[package]]
name = "http"
version = "1.2.0"
......@@ -983,7 +1150,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
"http 1.2.0",
]
[[package]]
......@@ -994,7 +1161,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f"
dependencies = [
"bytes",
"futures-util",
"http",
"http 1.2.0",
"http-body",
"pin-project-lite",
]
......@@ -1027,7 +1194,7 @@ dependencies = [
"futures-channel",
"futures-util",
"h2",
"http",
"http 1.2.0",
"http-body",
"httparse",
"httpdate",
......@@ -1060,7 +1227,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http",
"http 1.2.0",
"http-body",
"hyper",
"pin-project-lite",
......@@ -1273,6 +1440,12 @@ dependencies = [
"libc",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.13.0"
......@@ -1351,6 +1524,20 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "llmctl"
version = "0.2.0"
dependencies = [
"clap",
"serde",
"serde_json",
"tabled",
"tokio",
"tracing",
"triton-distributed",
"triton-llm",
]
[[package]]
name = "local-ip-address"
version = "0.6.3"
......@@ -1359,7 +1546,7 @@ checksum = "3669cf5561f8d27e8fc84cc15e58350e70f557d4d65f70e3154e54cd2f8e1782"
dependencies = [
"libc",
"neli",
"thiserror",
"thiserror 1.0.69",
"windows-sys 0.59.0",
]
......@@ -1394,6 +1581,12 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "memchr"
version = "2.7.4"
......@@ -1507,7 +1700,7 @@ checksum = "4abdf1789932b85dc39446e27f45a1064a30f9e19a2b872b1d09bd59283f85f3"
dependencies = [
"rand",
"serde",
"thiserror",
"thiserror 1.0.69",
]
[[package]]
......@@ -1607,6 +1800,17 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "papergrid"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b915f831b85d984193fdc3d3611505871dc139b2534530fa01c1a6a6707b6723"
dependencies = [
"bytecount",
"fnv",
"unicode-width",
]
[[package]]
name = "parking_lot"
version = "0.12.3"
......@@ -1813,7 +2017,7 @@ dependencies = [
"memchr",
"parking_lot",
"protobuf",
"thiserror",
"thiserror 1.0.69",
]
[[package]]
......@@ -2213,6 +2417,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_repr"
version = "0.1.19"
......@@ -2233,6 +2447,18 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd"
dependencies = [
"form_urlencoded",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "sha2"
version = "0.10.8"
......@@ -2401,6 +2627,29 @@ dependencies = [
"version-compare",
]
[[package]]
name = "tabled"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121d8171ee5687a4978d1b244f7d99c43e7385a272185a2f1e1fa4dc0979d444"
dependencies = [
"papergrid",
"tabled_derive",
]
[[package]]
name = "tabled_derive"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52d9946811baad81710ec921809e2af67ad77719418673b2a3794932d57b7538"
dependencies = [
"heck",
"proc-macro-error2",
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "target-lexicon"
version = "0.12.16"
......@@ -2427,7 +2676,16 @@ version = "1.0.69"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
dependencies = [
"thiserror-impl",
"thiserror-impl 1.0.69",
]
[[package]]
name = "thiserror"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
dependencies = [
"thiserror-impl 2.0.11",
]
[[package]]
......@@ -2441,6 +2699,17 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "thiserror-impl"
version = "2.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "thread_local"
version = "1.1.8"
......@@ -2567,7 +2836,7 @@ dependencies = [
"bytes",
"futures-core",
"futures-sink",
"http",
"http 1.2.0",
"httparse",
"rand",
"ring",
......@@ -2620,11 +2889,11 @@ checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
dependencies = [
"async-stream",
"async-trait",
"axum",
"axum 0.7.9",
"base64",
"bytes",
"h2",
"http",
"http 1.2.0",
"http-body",
"http-body-util",
"hyper",
......@@ -2686,8 +2955,10 @@ dependencies = [
"futures-util",
"pin-project-lite",
"sync_wrapper",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
......@@ -2708,6 +2979,7 @@ version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
......@@ -2810,7 +3082,7 @@ dependencies = [
"serde",
"serde_json",
"socket2",
"thiserror",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
......@@ -2821,6 +3093,33 @@ dependencies = [
"xxhash-rust",
]
[[package]]
name = "triton-llm"
version = "0.2.0"
dependencies = [
"anyhow",
"async-stream",
"async-trait",
"axum 0.8.1",
"bytes",
"chrono",
"derive_builder",
"futures",
"prometheus",
"regex",
"serde",
"serde_json",
"thiserror 2.0.11",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"triton-distributed",
"unicode-segmentation",
"uuid",
"validator",
]
[[package]]
name = "try-lock"
version = "0.2.5"
......@@ -2859,6 +3158,18 @@ version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd"
[[package]]
name = "untrusted"
version = "0.9.0"
......@@ -2888,6 +3199,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.13.1"
......
......@@ -16,6 +16,8 @@
[workspace]
members = [
"hello_world",
"http",
"llmctl",
]
resolver = "2"
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "http"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
[dependencies]
triton-distributed = { workspace = true}
triton-llm = { workspace = true}
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use triton_distributed::{logging, DistributedRuntime, Result, Runtime, Worker};
use triton_llm::http::service::{
discovery::{model_watcher, ModelWatchState},
service_v2::HttpService,
};
/// Entry point of the example HTTP service: sets up logging, builds a [`Worker`] from
/// settings and hands [`app`] to it for execution.
fn main() -> Result<()> {
    logging::init();
    Worker::from_settings()?.execute(app)
}
/// Runs the HTTP service on port 9992 and keeps its model manager in sync with the chat
/// model entries stored in etcd under the `public` namespace's `http` component.
async fn app(runtime: Runtime) -> Result<()> {
    let drt = DistributedRuntime::from_settings(runtime.clone()).await?;

    // Build the http service first; its model manager is shared with the etcd watcher.
    let service = HttpService::builder().port(9992).build()?;
    let manager = service.model_manager().clone();

    // todo - register this component through the IntoComponent trait and start a service.
    // todo - the service should create an entry and register a component definition, i.e. the
    //        type of the component and its config. For this example that would be an
    //        HttpServiceComponentDefinition object written to etcd.
    // todo - when the cli operates on an `http` component it should validate that
    //        namespace.component is registered with an HttpServiceComponentDefinition.
    let http_component = drt.namespace("public")?.component("http")?;
    let models_prefix = format!("{}/models/chat/", http_component.etcd_path());

    let watch_state = Arc::new(ModelWatchState {
        prefix: models_prefix.clone(),
        manager,
        drt: drt.clone(),
    });

    // Fetch the current entries and subscribe to changes below the chat-model prefix.
    let etcd = drt.etcd_client();
    let prefix_watch = etcd.kv_get_and_watch_prefix(models_prefix).await?;

    // `_watcher` stays bound until this function returns.
    let (_prefix, _watcher, events) = prefix_watch.dissolve();
    let _watcher_task = tokio::spawn(model_watcher(watch_state, events));

    service.run(runtime.child_token()).await
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[package]
name = "llmctl"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
[dependencies]
triton-distributed = { workspace = true}
triton-llm = { workspace = true}
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
tokio = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
tabled = "0.18"
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use clap::{Parser, Subcommand};
use tracing as log;
use triton_distributed::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime,
Result, Runtime, Worker,
};
use triton_llm::http::service::discovery::ModelEntry;
// Top-level command-line arguments of `llmctl`, parsed with clap's derive API.
//
// NOTE: clap uses the `///` doc comments on this type's fields as help text, so maintainer
// notes are written as plain `//` comments to keep the `--help` output unchanged.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    /// Namespace to operate in
    // Optional on the command line; `main` falls back to "public" when it is absent.
    #[arg(short = 'n', long)]
    namespace: Option<String>,
    // The subcommand to run; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}
// First-level subcommands of `llmctl`. `http` is the only one; it wraps the nested
// `HttpCommands`, which `handle_command` dispatches on.
//
// NOTE: the `///` doc comments below are used by clap as help text.
#[derive(Subcommand)]
enum Commands {
    /// HTTP service related commands
    Http {
        #[command(subcommand)]
        command: HttpCommands,
    },
}
// Subcommands of `llmctl http`.
//
// Every variant starts with a positional `chat-model` argument naming the kind of object
// being operated on. It is validated by `parse_chat_model` in all three variants so that
// `add`, `list` and `remove` accept exactly the same spellings (`chat-model` or
// `chat-models`); `handle_command` ignores the parsed value.
//
// NOTE: the `///` doc comments below are used by clap as help text.
#[derive(Subcommand)]
enum HttpCommands {
    /// Add a chat model
    Add {
        /// Specifies we're adding a chat model
        #[arg(value_name = "chat-model", value_parser = parse_chat_model)]
        chat_model: String,
        /// Model name (e.g. foo/v1)
        model_name: String,
        /// Endpoint name (format: component.endpoint or namespace.component.endpoint)
        endpoint_name: String,
    },
    /// List chat models
    List {
        /// Specifies we're listing chat models
        #[arg(value_name = "chat-model", value_parser = parse_chat_model)]
        chat_model: String,
    },
    /// Remove a chat model
    Remove {
        /// Specifies we're removing a chat model
        #[arg(value_name = "chat-model", value_parser = parse_chat_model)]
        chat_model: String,
        /// Name of the model to remove
        name: String,
    },
}
/// clap value parser for the `chat-model` positional argument.
///
/// Accepts exactly `chat-model` or `chat-models` and returns the input unchanged; any other
/// value is rejected with an error.
fn parse_chat_model(s: &str) -> Result<String> {
    if !matches!(s, "chat-model" | "chat-models") {
        raise!("Expected 'chat-model' or 'chat-models'");
    }
    Ok(s.to_owned())
}
fn main() -> Result<()> {
logging::init();
let cli = Cli::parse();
// Default namespace to "public" if not specified
let namespace = cli.namespace.unwrap_or_else(|| "public".to_string());
let worker = Worker::from_settings()?;
worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await })
}
async fn handle_command(runtime: Runtime, namespace: String, command: Commands) -> Result<()> {
let settings = DistributedConfig::for_cli();
let distributed = DistributedRuntime::new(runtime, settings).await?;
match command {
Commands::Http { command } => {
match command {
HttpCommands::Add {
chat_model: _,
model_name,
endpoint_name,
} => {
log::debug!(
"Adding model {} with endpoint {}",
model_name,
endpoint_name
);
// parse endpoint
// split by '.' must have 2, can have 3 parts, any more or less is an error
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 || parts.len() > 3 {
raise!("Invalid endpoint name: {}", endpoint_name);
}
// if 3 parts, then it's namespace.component.endpoint
// if 2 parts, then it's model_name.component.endpoint
// create model entry
let endpoint = Endpoint {
namespace: if parts.len() == 3 {
parts[0].to_string()
} else {
namespace.clone()
},
component: parts[parts.len() - 2].to_string(),
name: parts[parts.len() - 1].to_string(),
};
let model = ModelEntry {
name: model_name.clone(),
endpoint,
};
// add model to etcd
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!("{}/models/chat/{}", component.etcd_path(), model_name);
let etcd_client = distributed.etcd_client();
etcd_client
.kv_create(path, serde_json::to_vec_pretty(&model)?, None)
.await?;
println!("Model {} added to namespace {}", model_name, namespace);
}
HttpCommands::List { chat_model: _ } => {
let component = distributed.namespace(&namespace)?.component("http")?;
// todo - make this part of the http discovery service object
let prefix = format!("{}/models/chat/", component.etcd_path());
// get the kvs from etcd
let etcd_client = distributed.etcd_client();
let kvs = etcd_client.kv_get_prefix(&prefix).await?;
use tabled::Tabled;
#[derive(Tabled)]
struct ModelRow {
#[tabled(rename = "MODEL NAME")]
name: String,
#[tabled(rename = "NAMESPACE")]
namespace: String,
#[tabled(rename = "COMPONENT")]
component: String,
#[tabled(rename = "ENDPOINT")]
endpoint: String,
}
// parse the keys
let mut models = Vec::new();
for kv in kvs {
match (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
(Ok(key), Ok(model)) => {
models.push(ModelRow {
name: key.trim_start_matches(&prefix).to_string(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
(Err(e), _) => {
log::debug!("Error parsing key: {}", e);
}
(_, Err(e)) => {
log::debug!("Error parsing value: {}", e);
}
}
}
if models.is_empty() {
println!("No chat models found in namespace {}", namespace);
} else {
let table = tabled::Table::new(models);
println!("Listing chat models in namespace {}", namespace);
println!("{}", table);
}
}
HttpCommands::Remove {
chat_model: _,
name,
} => {
// TODO: Implement remove logic
log::debug!("Removing model {}", name);
let component = distributed.namespace(&namespace)?.component("http")?;
// todo - make this part of the http discovery service object
let prefix = format!("{}/models/chat/{name}", component.etcd_path());
log::debug!("deleting key: {}", prefix);
// get the kvs from etcd
let mut kv_client = distributed.etcd_client().etcd_client().kv_client();
match kv_client.delete(prefix.as_bytes(), None).await {
Ok(_response) => {
println!("Model {} removed from namespace {}", name, namespace);
}
Err(e) => {
log::error!("Error removing model {}: {}", name, e);
}
}
}
}
}
}
Ok(())
}
......@@ -14,19 +14,3 @@
// limitations under the License.
pub mod service;
use serde::{Deserialize, Serialize};
use triton_distributed::protocols;
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
    /// Public name of the model.
    /// This will be used to identify the model in the HTTP service and is the value used in
    /// an [OAI ChatRequest][crate::protocols::openai::chat_completions::ChatCompletionRequest].
    name: String,
    /// Endpoint associated with the model.
    endpoint: protocols::Endpoint,
}
......@@ -32,6 +32,7 @@
mod openai;
pub mod discovery;
pub mod error;
pub mod metrics;
pub mod service_v2;
......
......@@ -13,14 +13,126 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use tracing as log;
use triton_distributed::{
component::ComponentEndpointInfo, transports::etcd::WatchEvent, DistributedRuntime, Result,
Runtime, Worker,
protocols::{self, annotated::Annotated},
raise,
transports::etcd::{KeyValue, WatchEvent},
DistributedRuntime, Result,
};
use triton_llm::http::service::{
service_v2::{HttpService, HttpServiceConfig},
ModelManager,
use super::ModelManager;
use crate::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
};
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
    /// Public name of the model.
    /// This will be used to identify the model in the HTTP service and is the value used in
    /// an [OAI ChatRequest][crate::protocols::openai::chat_completions::ChatCompletionRequest].
    pub name: String,
    /// Endpoint (namespace, component and endpoint name) that the HTTP service connects to
    /// in order to serve this model.
    pub endpoint: protocols::Endpoint,
}
/// Shared state used by [`model_watcher`] while it processes etcd watch events.
pub struct ModelWatchState {
    /// etcd key prefix of the watched model entries; it is removed from an event's key to
    /// obtain the model name.
    pub prefix: String,
    /// Model manager that chat completion models are added to and removed from.
    pub manager: ModelManager,
    /// Distributed runtime used to create the client for a model's endpoint.
    pub drt: DistributedRuntime,
}
pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<WatchEvent>) {
log::debug!("model watcher started");
let mut events_rx = events_rx;
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
Ok(model_name) => {
log::info!("added chat model: {}", model_name);
}
Err(e) => {
log::error!("error adding chat model: {}", e);
// log::warn!(
// "deleting offending key: {}",
// kv.key_str().unwrap_or_default()
// );
// if let Err(e) = kv_client.delete(kv.key(), None).await {
// log::error!("failed to delete offending key: {}", e);
// }
}
},
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok(model_name) => {
log::info!("removed chat model: {}", model_name);
}
Err(e) => {
log::error!("error removing chat model: {}", e);
}
},
}
}
log::debug!("model watcher stopped");
}
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> {
log::debug!("removing model");
let key = kv.key_str()?;
log::debug!("key: {}", key);
let model_name = key.trim_start_matches(&state.prefix);
state.manager.remove_chat_completions_model(model_name)?;
Ok(model_name.to_string())
}
// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
//
// The model name is the event's key with the watched prefix removed exactly once, and it
// must equal the `name` stored in the entry. On success a client for the entry's endpoint
// is created, registered with the model manager under that name, and the name is returned.
//
// Errors if the key is not valid UTF-8 or not under the watched prefix, if the value is not
// a valid `ModelEntry`, if the names disagree, or if creating or registering the client
// fails. The caller currently only logs such errors; the offending key is not deleted.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> {
    log::debug!("adding model");

    let key = kv.key_str()?;
    log::debug!("key: {}", key);

    // `strip_prefix` removes one leading occurrence only, whereas `trim_start_matches`
    // would also strip a model name that itself begins with the prefix.
    let model_name = match key.strip_prefix(state.prefix.as_str()) {
        Some(name) => name,
        None => raise!("key {} is not under the prefix {}", key, state.prefix),
    };

    let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;

    // this means there is an entry in etcd that breaks the contract that the key
    // in the models path must match the model name in the entry.
    if model_entry.name != model_name {
        raise!(
            "model name mismatch: {} != {}",
            model_entry.name,
            model_name
        );
    }

    let client = state
        .drt
        .namespace(model_entry.endpoint.namespace)?
        .component(model_entry.endpoint.component)?
        .endpoint(model_entry.endpoint.name)
        .client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
        .await?;

    let client = Arc::new(client);

    state
        .manager
        .add_chat_completions_model(model_name, client)?;

    Ok(model_name.to_string())
}
......@@ -141,4 +141,15 @@ impl DistributedConfig {
nats_config: nats::ClientOptions::default(),
}
}
/// Builds the configuration used by command-line tools: default etcd and NATS client
/// options, except that `attach_lease` is switched off on the etcd options.
pub fn for_cli() -> DistributedConfig {
    let mut etcd_config = etcd::ClientOptions::default();
    etcd_config.attach_lease = false;

    DistributedConfig {
        etcd_config,
        nats_config: nats::ClientOptions::default(),
    }
}
}
......@@ -20,7 +20,9 @@
use std::sync::{Arc, Mutex};
pub use anyhow::{anyhow as error, Context as ErrorContext, Error, Ok as OK, Result};
pub use anyhow::{
anyhow as error, bail as raise, Context as ErrorContext, Error, Ok as OK, Result,
};
use async_once_cell::OnceCell;
......@@ -33,6 +35,7 @@ pub mod engine;
pub mod logging;
pub mod pipeline;
pub mod protocols;
pub mod runnable;
pub mod runtime;
pub mod service;
pub mod transports;
......
......@@ -56,7 +56,7 @@ use tracing_subscriber::{filter::Directive, fmt};
const FILTER_ENV: &str = "TRD_LOG";
/// Default log level
const DEFAULT_FILTER_LEVEL: &str = "info";
const DEFAULT_FILTER_LEVEL: &str = "error";
/// ENV used to set the path to the logging configuration file
const CONFIG_PATH_ENV: &str = "TRD_LOGGING_CONFIG_PATH";
......@@ -73,7 +73,13 @@ impl Default for LoggingConfig {
fn default() -> Self {
LoggingConfig {
log_level: DEFAULT_FILTER_LEVEL.to_string(),
log_filters: HashMap::new(),
log_filters: HashMap::from([
("h2".to_string(), "error".to_string()),
("tower".to_string(), "error".to_string()),
("hyper_util".to_string(), "error".to_string()),
("neli".to_string(), "error".to_string()),
("async_nats".to_string(), "error".to_string()),
]),
}
}
}
......
......@@ -35,10 +35,6 @@ pub struct Endpoint {
/// Namespace of the component.
pub namespace: String,
/// Optional lease id for the endpoint.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub lease: Option<LeaseId>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
......@@ -82,7 +78,6 @@ mod tests {
name: "test_endpoint".to_string(),
component: "test_component".to_string(),
namespace: "test_namespace".to_string(),
lease: None,
};
assert_eq!(endpoint.name, "test_endpoint");
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Runnable Module.
//!
//! This module provides a way to run a task in a runtime.
//!
use std::{
pin::Pin,
task::{Context, Poll},
};
pub use crate::{Error, Result};
pub use async_trait::async_trait;
pub use tokio::task::JoinHandle;
pub use tokio_util::sync::CancellationToken;
/// Handle to a task being run, exposing its completion state, its cancellation and its
/// [`JoinHandle`].
///
/// NOTE(review): no implementor is visible from this module, so the per-method notes below
/// are inferred from names and signatures only — confirm against the implementations.
#[async_trait]
pub trait ExecutionHandle {
    /// Reports whether the task has finished (presumably like `JoinHandle::is_finished` —
    /// TODO confirm).
    fn is_finished(&self) -> bool;
    /// Reports whether cancellation has been requested (presumably via the associated
    /// [`CancellationToken`] — TODO confirm).
    fn is_cancelled(&self) -> bool;
    /// Requests cancellation of the task.
    fn cancel(&self);
    /// Returns a [`CancellationToken`] associated with the task.
    fn cancellation_token(&self) -> CancellationToken;
    /// Consumes the handle and returns a [`JoinHandle`] yielding the task's `Result<()>`.
    fn handle(self) -> JoinHandle<Result<()>>;
}
......@@ -20,13 +20,12 @@ use derive_builder::Builder;
use derive_getters::Dissolve;
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing as log;
use validator::Validate;
use etcd_client::{
Compare, CompareOp, GetOptions, KeyValue, PutOptions, Txn, TxnOp, WatchOptions, Watcher,
};
use etcd_client::{Compare, CompareOp, GetOptions, PutOptions, Txn, TxnOp, WatchOptions, Watcher};
pub use etcd_client::{ConnectOptions, LeaseClient};
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
mod lease;
use lease::*;
......@@ -192,6 +191,9 @@ impl Client {
.ok_or(error!("missing header; unable to get revision"))?
.revision();
log::trace!("start_revision: {}", start_revision);
let start_revision = start_revision + 1;
let (watcher, mut watch_stream) = watch_client
.watch(
prefix.as_ref(),
......@@ -204,6 +206,7 @@ impl Client {
.await?;
let kvs = get_response.take_kvs();
log::trace!("initial kv count: {:?}", kvs.len());
let (tx, rx) = mpsc::channel(32);
......@@ -263,14 +266,14 @@ pub enum WatchEvent {
#[derive(Debug, Clone, Builder, Validate)]
pub struct ClientOptions {
#[validate(length(min = 1))]
etcd_url: Vec<String>,
pub etcd_url: Vec<String>,
#[builder(default)]
etcd_connect_options: Option<ConnectOptions>,
pub etcd_connect_options: Option<ConnectOptions>,
/// If true, the client will attach a lease to the primary [`CancellationToken`].
#[builder(default = "true")]
attach_lease: bool,
pub attach_lease: bool,
}
impl Default for ClientOptions {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment