Commit b760c569 authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: Add completion endpoint to http server and llmctl (#230)


Co-authored-by: default avataraflowers <aflowers@nvidia.com>
parent 113f4d91
...@@ -2867,6 +2867,28 @@ version = "0.11.1" ...@@ -2867,6 +2867,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
...@@ -3410,6 +3432,7 @@ dependencies = [ ...@@ -3410,6 +3432,7 @@ dependencies = [
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",
"strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
"tokio", "tokio",
......
...@@ -3016,6 +3016,28 @@ version = "0.11.1" ...@@ -3016,6 +3016,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
...@@ -3582,6 +3604,7 @@ dependencies = [ ...@@ -3582,6 +3604,7 @@ dependencies = [
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",
"strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
"tokio", "tokio",
......
...@@ -16,11 +16,16 @@ ...@@ -16,11 +16,16 @@
use clap::Parser; use clap::Parser;
use std::sync::Arc; use std::sync::Arc;
use triton_distributed_llm::http::service::{ use triton_distributed_llm::{
discovery::{model_watcher, ModelWatchState}, http::service::{
service_v2::HttpService, discovery::{model_watcher, ModelWatchState},
service_v2::HttpService,
},
model_type::ModelType,
};
use triton_distributed_runtime::{
logging, transports::etcd::PrefixWatcher, DistributedRuntime, Result, Runtime, Worker,
}; };
use triton_distributed_runtime::{logging, DistributedRuntime, Result, Runtime, Worker};
#[derive(Parser)] #[derive(Parser)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
...@@ -50,10 +55,8 @@ fn main() -> Result<()> { ...@@ -50,10 +55,8 @@ fn main() -> Result<()> {
async fn app(runtime: Runtime) -> Result<()> { async fn app(runtime: Runtime) -> Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let args = Args::parse(); let args = Args::parse();
// create the http service and acquire the model manager
let http_service = HttpService::builder() let http_service = HttpService::builder()
.port(args.port) .port(args.port)
.host(args.host) .host(args.host)
...@@ -71,21 +74,29 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -71,21 +74,29 @@ async fn app(runtime: Runtime) -> Result<()> {
let component = distributed let component = distributed
.namespace(&args.namespace)? .namespace(&args.namespace)?
.component(&args.component)?; .component(&args.component)?;
let etcd_root = component.etcd_path(); let etcd_root = component.etcd_path();
let etcd_path = format!("{}/models/chat/", etcd_root);
let state = Arc::new(ModelWatchState { // Create watchers for all model types
prefix: etcd_path.clone(), let mut watcher_tasks = Vec::new();
manager,
drt: distributed.clone(), for model_type in ModelType::all() {
}); let etcd_path = format!("{}/models/{}/", etcd_root, model_type.as_str());
let state = Arc::new(ModelWatchState {
prefix: etcd_path.clone(),
model_type,
manager: manager.clone(),
drt: distributed.clone(),
});
let etcd_client = distributed.etcd_client(); let etcd_client = distributed.etcd_client();
let models_watcher = etcd_client.kv_get_and_watch_prefix(etcd_path).await?; let models_watcher: PrefixWatcher = etcd_client.kv_get_and_watch_prefix(etcd_path).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(model_watcher(state, receiver)); let watcher_task = tokio::spawn(model_watcher(state, receiver));
watcher_tasks.push(watcher_task);
}
// Run the service
http_service.run(runtime.child_token()).await http_service.run(runtime.child_token()).await
} }
...@@ -16,18 +16,96 @@ ...@@ -16,18 +16,96 @@
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use tracing as log; use tracing as log;
use triton_distributed_llm::http::service::discovery::ModelEntry; use triton_distributed_llm::{http::service::discovery::ModelEntry, model_type::ModelType};
use triton_distributed_runtime::{ use triton_distributed_runtime::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime, distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime,
Result, Runtime, Worker, Result, Runtime, Worker,
}; };
// Macro to define model types and associated commands
macro_rules! define_type_subcommands {
($(($variant:ident, $primary_name:expr, [$($alias:expr),*], $help:expr)),* $(,)?) => {
#[derive(Subcommand)]
enum AddCommands {
$(
#[doc = $help]
#[command(name = $primary_name, aliases = [$($alias),*])]
$variant(AddModelArgs),
)*
}
#[derive(Subcommand)]
enum ListCommands {
$(
#[doc = concat!("List ", $primary_name, " models")]
#[command(name = $primary_name, aliases = [$($alias),*])]
$variant,
)*
}
#[derive(Subcommand)]
enum RemoveCommands {
$(
#[doc = concat!("Remove ", $primary_name, " model")]
#[command(name = $primary_name, aliases = [$($alias),*])]
$variant(RemoveModelArgs),
)*
}
impl AddCommands {
fn into_parts(self) -> (ModelType, String, String) {
match self {
$(Self::$variant(args) => (ModelType::$variant, args.model_name, args.endpoint_name)),*
}
}
}
impl RemoveCommands {
fn into_parts(self) -> (ModelType, String) {
match self {
$(Self::$variant(args) => (ModelType::$variant, args.model_name)),*
}
}
}
impl ListCommands {
fn model_type(&self) -> ModelType {
match self {
$(Self::$variant => ModelType::$variant),*
}
}
}
}
}
define_type_subcommands!(
(
Chat,
"chat",
["chat-model", "chat-models"],
"Add a chat model"
),
(
Completion,
"completion",
["completions", "completion-model"],
"Add a completion model"
),
// Add new model types here:
);
#[derive(Parser)] #[derive(Parser)]
#[command(author, version, about, long_about = None)] #[command(
author="NVIDIA",
version="0.2.1",
about="LLMCTL - Control and manage TRD Components",
long_about = None,
disable_help_subcommand = true,
)]
struct Cli { struct Cli {
/// Namespace to operate in /// Public Namespace to operate in
#[arg(short = 'n', long)] #[arg(short = 'n', long)]
namespace: Option<String>, public_namespace: Option<String>,
#[command(subcommand)] #[command(subcommand)]
command: Commands, command: Commands,
...@@ -44,42 +122,41 @@ enum Commands { ...@@ -44,42 +122,41 @@ enum Commands {
#[derive(Subcommand)] #[derive(Subcommand)]
enum HttpCommands { enum HttpCommands {
/// Add a chat model /// Add models
Add { Add {
/// Specifies we're adding a chat model #[command(subcommand)]
#[arg(value_name = "chat-model")] model_type: AddCommands,
chat_model: String,
/// Model name (e.g. foo/v1)
model_name: String,
/// Endpoint name (format: component.endpoint or namespace.component.endpoint)
endpoint_name: String,
}, },
/// List chat models /// List models (all types if no specific type provided)
List { List {
/// Specifies we're listing chat models #[command(subcommand)]
#[arg(value_name = "chat-model", value_parser = parse_chat_model)] model_type: Option<ListCommands>,
chat_model: String,
}, },
/// Remove a chat model /// Remove models
Remove { Remove {
/// Specifies we're removing a chat model #[command(subcommand)]
#[arg(value_name = "chat-model")] model_type: RemoveCommands,
chat_model: String,
/// Name of the model to remove
name: String,
}, },
} }
fn parse_chat_model(s: &str) -> Result<String> { #[derive(Parser)]
match s { struct AddModelArgs {
"chat-model" | "chat-models" => Ok(s.to_string()), /// Model name (e.g. foo/v1)
_ => raise!("Expected 'chat-model' or 'chat-models'"), #[arg(name = "model-name")]
} model_name: String,
/// Endpoint name (format: component.endpoint or namespace.component.endpoint)
#[arg(name = "endpoint-name")]
endpoint_name: String,
}
/// Common fields for removing any model type
#[derive(Parser)]
struct RemoveModelArgs {
/// Name of the model to remove
#[arg(name = "model-name")]
model_name: String,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
...@@ -87,7 +164,7 @@ fn main() -> Result<()> { ...@@ -87,7 +164,7 @@ fn main() -> Result<()> {
let cli = Cli::parse(); let cli = Cli::parse();
// Default namespace to "public" if not specified // Default namespace to "public" if not specified
let namespace = cli.namespace.unwrap_or_else(|| "public".to_string()); let namespace = cli.public_namespace.unwrap_or_else(|| "public".to_string());
let worker = Worker::from_settings()?; let worker = Worker::from_settings()?;
worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await }) worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await })
...@@ -100,136 +177,262 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands) ...@@ -100,136 +177,262 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands)
match command { match command {
Commands::Http { command } => { Commands::Http { command } => {
match command { match command {
HttpCommands::Add { HttpCommands::Add { model_type } => {
chat_model: _, let (model_type, model_name, endpoint_name) = model_type.into_parts();
model_name, add_model(
endpoint_name, &distributed,
} => { namespace.to_string(),
log::debug!( model_type,
"Adding model {} with endpoint {}",
model_name, model_name,
endpoint_name &endpoint_name,
); )
.await?;
// parse endpoint
// split by '.' must have 2, can have 3 parts, any more or less is an error
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 || parts.len() > 3 {
raise!("Invalid endpoint name: {}", endpoint_name);
}
// if 3 parts, then it's namespace.component.endpoint
// if 2 parts, then it's model_name.component.endpoint
// create model entry
let endpoint = Endpoint {
namespace: if parts.len() == 3 {
parts[0].to_string()
} else {
namespace.clone()
},
component: parts[parts.len() - 2].to_string(),
name: parts[parts.len() - 1].to_string(),
};
let model = ModelEntry {
name: model_name.clone(),
endpoint,
};
// add model to etcd
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!("{}/models/chat/{}", component.etcd_path(), model_name);
let etcd_client = distributed.etcd_client();
etcd_client
.kv_create(path, serde_json::to_vec_pretty(&model)?, None)
.await?;
println!("Model {} added to namespace {}", model_name, namespace);
} }
HttpCommands::List { chat_model: _ } => { HttpCommands::List { model_type } => {
let component = distributed.namespace(&namespace)?.component("http")?; match model_type {
// todo - make this part of the http discovery service object Some(model_type) => {
let prefix = format!("{}/models/chat/", component.etcd_path()); list_models(
&distributed,
// get the kvs from etcd namespace.clone(),
let etcd_client = distributed.etcd_client(); Some(model_type.model_type()),
let kvs = etcd_client.kv_get_prefix(&prefix).await?; )
.await?;
use tabled::Tabled;
#[derive(Tabled)]
struct ModelRow {
#[tabled(rename = "MODEL NAME")]
name: String,
#[tabled(rename = "NAMESPACE")]
namespace: String,
#[tabled(rename = "COMPONENT")]
component: String,
#[tabled(rename = "ENDPOINT")]
endpoint: String,
}
// parse the keys
let mut models = Vec::new();
for kv in kvs {
match (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
(Ok(key), Ok(model)) => {
models.push(ModelRow {
name: key.trim_start_matches(&prefix).to_string(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
(Err(e), _) => {
log::debug!("Error parsing key: {}", e);
}
(_, Err(e)) => {
log::debug!("Error parsing value: {}", e);
}
} }
} None => {
// List all model types
if models.is_empty() { list_models(&distributed, namespace.clone(), None).await?;
println!("No chat models found in namespace {}", namespace);
} else {
let table = tabled::Table::new(models);
println!("Listing chat models in namespace {}", namespace);
println!("{}", table);
}
}
HttpCommands::Remove {
chat_model: _,
name,
} => {
// TODO: Implement remove logic
log::debug!("Removing model {}", name);
let component = distributed.namespace(&namespace)?.component("http")?;
// todo - make this part of the http discovery service object
let prefix = format!("{}/models/chat/{name}", component.etcd_path());
log::debug!("deleting key: {}", prefix);
// get the kvs from etcd
let mut kv_client = distributed.etcd_client().etcd_client().kv_client();
match kv_client.delete(prefix.as_bytes(), None).await {
Ok(_response) => {
println!("Model {} removed from namespace {}", name, namespace);
}
Err(e) => {
log::error!("Error removing model {}: {}", name, e);
} }
} }
} }
HttpCommands::Remove { model_type } => {
let (model_type, name) = model_type.into_parts();
remove_model(&distributed, namespace.to_string(), model_type, &name).await?;
}
}
}
}
Ok(())
}
// Helper functions to handle the actual operations
async fn add_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
model_name: String,
endpoint_name: &str,
) -> Result<()> {
log::debug!(
"Adding model {} with endpoint {}",
model_name,
endpoint_name
);
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 {
raise!("Endpoint name '{}' is too short. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
} else if parts.len() > 3 {
raise!("Endpoint name '{}' is too long. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
}
// create model entry
let endpoint = Endpoint {
namespace: if parts.len() == 3 {
parts[0].to_string()
} else {
println!(
"Using the public namespace: {} for model: {}",
namespace, model_name
);
namespace.clone()
},
component: parts[parts.len() - 2].to_string(),
name: parts[parts.len() - 1].to_string(),
};
let model = ModelEntry {
name: model_name.to_string(),
endpoint,
model_type,
};
// add model to etcd
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_path(),
model_type.as_str(),
model_name
);
let etcd_client = distributed.etcd_client();
// check if model already exists
let kvs = etcd_client.kv_get_prefix(&path).await?;
if !kvs.is_empty() {
println!(
"{} model {} already exists, please remove it before changing the endpoint.",
model_type.as_str(),
model_name,
);
list_single_model(distributed, namespace, model_type, model_name).await?;
} else {
etcd_client
.kv_create(path, serde_json::to_vec_pretty(&model)?, None)
.await?;
println!("Added new {} model {}", model_type.as_str(), model_name,);
list_single_model(distributed, namespace, model_type, model_name).await?;
}
Ok(())
}
#[derive(tabled::Tabled)]
struct ModelRow {
#[tabled(rename = "MODEL TYPE")]
model_type: String,
#[tabled(rename = "MODEL NAME")]
name: String,
#[tabled(rename = "NAMESPACE")]
namespace: String,
#[tabled(rename = "COMPONENT")]
component: String,
#[tabled(rename = "ENDPOINT")]
endpoint: String,
}
async fn list_single_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
model_name: String,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_path(),
model_type.as_str(),
model_name
);
let mut models = Vec::new();
let etcd_client = distributed.etcd_client();
let kvs = etcd_client.kv_get_prefix(&path).await?;
for kv in kvs {
if let (Ok(_key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: model_type.as_str().to_string(),
name: model_name.clone(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
}
if models.is_empty() {
println!("Something went wrong, no model was found.");
} else {
let table = tabled::Table::new(models);
println!("{}", table);
}
Ok(())
}
async fn list_models(
distributed: &DistributedRuntime,
namespace: String,
model_type: Option<ModelType>,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let mut models = Vec::new();
let model_types = match model_type {
Some(mt) => vec![mt],
None => ModelType::all(),
};
for mt in model_types {
let prefix = format!("{}/models/{}/", component.etcd_path(), mt.as_str(),);
let etcd_client = distributed.etcd_client();
let kvs = etcd_client.kv_get_prefix(&prefix).await?;
for kv in kvs {
if let (Ok(key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: mt.as_str().to_string(),
name: key.trim_start_matches(&prefix).to_string(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
} }
} }
} }
if models.is_empty() {
match &model_type {
Some(mt) => println!(
"No {} models found in the public namespace: {}",
mt.as_str(),
namespace
),
None => println!("No models found in the public namespace: {}", namespace),
}
} else {
let table = tabled::Table::new(models);
match &model_type {
Some(mt) => println!(
"Listing {} models in the public namespace: {}",
mt.as_str(),
namespace
),
None => println!("Listing all models in the public namespace: {}", namespace),
}
println!("{}", table);
}
Ok(())
}
async fn remove_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
name: &str,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let prefix = format!(
"{}/models/{}/{}",
component.etcd_path(),
model_type.as_str(),
name
);
log::debug!("deleting key: {}", prefix);
// get the kvs from etcd
let mut kv_client = distributed.etcd_client().etcd_client().kv_client();
match kv_client.delete(prefix.as_bytes(), None).await {
Ok(_response) => {
println!(
"{} model {} removed from the public namespace: {}",
model_type.as_str(),
name,
namespace
);
}
Err(e) => {
log::error!("Error removing model {}: {}", name, e);
}
}
Ok(()) Ok(())
} }
...@@ -2806,7 +2806,7 @@ dependencies = [ ...@@ -2806,7 +2806,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_plain", "serde_plain",
"serde_yaml", "serde_yaml",
"strum", "strum 0.26.3",
"sysinfo", "sysinfo",
"thiserror 1.0.69", "thiserror 1.0.69",
"tokenizers", "tokenizers",
...@@ -4463,7 +4463,16 @@ version = "0.26.3" ...@@ -4463,7 +4463,16 @@ version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [ dependencies = [
"strum_macros", "strum_macros 0.26.4",
]
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros 0.27.1",
] ]
[[package]] [[package]]
...@@ -4479,6 +4488,19 @@ dependencies = [ ...@@ -4479,6 +4488,19 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
...@@ -5179,6 +5201,7 @@ dependencies = [ ...@@ -5179,6 +5201,7 @@ dependencies = [
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.27.1",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
"tokio", "tokio",
......
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_llm::{ use triton_distributed_llm::{
backend::Backend, backend::Backend,
http::service::discovery::ModelEntry,
model_type::ModelType,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
types::{ types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta}, openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
...@@ -80,6 +81,7 @@ pub async fn run( ...@@ -80,6 +81,7 @@ pub async fn run(
let model_registration = ModelEntry { let model_registration = ModelEntry {
name: service_name.to_string(), name: service_name.to_string(),
endpoint: endpoint.clone(), endpoint: endpoint.clone(),
model_type: ModelType::Chat,
}; };
etcd_client etcd_client
.kv_create( .kv_create(
......
...@@ -18,6 +18,7 @@ use std::sync::Arc; ...@@ -18,6 +18,7 @@ use std::sync::Arc;
use triton_distributed_llm::{ use triton_distributed_llm::{
backend::Backend, backend::Backend,
http::service::{discovery, service_v2}, http::service::{discovery, service_v2},
model_type::ModelType,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
types::{ types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta}, openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
...@@ -49,6 +50,7 @@ pub async fn run( ...@@ -49,6 +50,7 @@ pub async fn run(
// Listen for models registering themselves in etcd, add them to HTTP service // Listen for models registering themselves in etcd, add them to HTTP service
let state = Arc::new(discovery::ModelWatchState { let state = Arc::new(discovery::ModelWatchState {
prefix: service_name.clone(), prefix: service_name.clone(),
model_type: ModelType::Chat, // Tio currently supports only chat models
manager: http_service.model_manager().clone(), manager: http_service.model_manager().clone(),
drt: distributed_runtime.clone(), drt: distributed_runtime.clone(),
}); });
......
...@@ -2996,6 +2996,28 @@ version = "0.11.1" ...@@ -2996,6 +2996,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.96",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
...@@ -3539,6 +3561,7 @@ dependencies = [ ...@@ -3539,6 +3561,7 @@ dependencies = [
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",
"strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
"tokio", "tokio",
......
...@@ -44,4 +44,4 @@ tracing = "0" ...@@ -44,4 +44,4 @@ tracing = "0"
libc = "0.2" libc = "0.2"
uuid = { version = "1", features = ["v4", "serde"] } uuid = { version = "1", features = ["v4", "serde"] }
async-once-cell = "0.5.4" async-once-cell = "0.5.4"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
\ No newline at end of file
...@@ -18,8 +18,6 @@ use libc::c_char; ...@@ -18,8 +18,6 @@ use libc::c_char;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use std::ffi::CStr; use std::ffi::CStr;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use tracing as log;
use uuid::Uuid;
use triton_distributed_llm::kv_router::{ use triton_distributed_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher, indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
...@@ -39,7 +37,7 @@ fn initialize_tracing() { ...@@ -39,7 +37,7 @@ fn initialize_tracing() {
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
log::debug!("Tracing initialized"); tracing::debug!("Tracing initialized");
} }
#[repr(u32)] #[repr(u32)]
...@@ -141,7 +139,7 @@ fn triton_create_kv_publisher( ...@@ -141,7 +139,7 @@ fn triton_create_kv_publisher(
component: String, component: String,
worker_id: i64, worker_id: i64,
) -> Result<KvEventPublisher, anyhow::Error> { ) -> Result<KvEventPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", component); tracing::info!("Creating KV Publisher for model: {}", component);
match DRT match DRT
.get() .get()
.ok_or(anyhow::Error::msg("Could not get Distributed Runtime")) .ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
...@@ -197,7 +195,7 @@ fn kv_event_create_stored_from_parts( ...@@ -197,7 +195,7 @@ fn kv_event_create_stored_from_parts(
}) })
.is_ok() .is_ok()
{ {
log::warn!( tracing::warn!(
"Block size must be 64 tokens to be published. Block size is: {}", "Block size must be 64 tokens to be published. Block size is: {}",
num_toks num_toks
); );
......
...@@ -93,9 +93,9 @@ dependencies = [ ...@@ -93,9 +93,9 @@ dependencies = [
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.96" version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4" checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04"
[[package]] [[package]]
name = "arrayref" name = "arrayref"
...@@ -379,16 +379,15 @@ checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" ...@@ -379,16 +379,15 @@ checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
[[package]] [[package]]
name = "blake3" name = "blake3"
version = "1.6.0" version = "1.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937" checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e"
dependencies = [ dependencies = [
"arrayref", "arrayref",
"arrayvec", "arrayvec",
"cc", "cc",
"cfg-if 1.0.0", "cfg-if 1.0.0",
"constant_time_eq", "constant_time_eq",
"memmap2",
] ]
[[package]] [[package]]
...@@ -451,9 +450,9 @@ dependencies = [ ...@@ -451,9 +450,9 @@ dependencies = [
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.2.15" version = "1.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
...@@ -505,18 +504,18 @@ dependencies = [ ...@@ -505,18 +504,18 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.30" version = "4.5.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" checksum = "3e77c3243bd94243c03672cb5154667347c457ca271254724f9f393aee1c05ff"
dependencies = [ dependencies = [
"clap_builder", "clap_builder",
] ]
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.30" version = "4.5.27"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7"
dependencies = [ dependencies = [
"anstream", "anstream",
"anstyle", "anstyle",
...@@ -940,9 +939,9 @@ dependencies = [ ...@@ -940,9 +939,9 @@ dependencies = [
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.2" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]] [[package]]
name = "erased-serde" name = "erased-serde"
...@@ -1041,15 +1040,15 @@ dependencies = [ ...@@ -1041,15 +1040,15 @@ dependencies = [
[[package]] [[package]]
name = "fixedbitset" name = "fixedbitset"
version = "0.5.7" version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]] [[package]]
name = "flate2" name = "flate2"
version = "1.0.35" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
dependencies = [ dependencies = [
"crc32fast", "crc32fast",
"miniz_oxide", "miniz_oxide",
...@@ -1241,9 +1240,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" ...@@ -1241,9 +1240,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.8" version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e"
dependencies = [ dependencies = [
"atomic-waker", "atomic-waker",
"bytes", "bytes",
...@@ -1653,6 +1652,15 @@ dependencies = [ ...@@ -1653,6 +1652,15 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.14.0" version = "0.14.0"
...@@ -1765,9 +1773,9 @@ dependencies = [ ...@@ -1765,9 +1773,9 @@ dependencies = [
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.26" version = "0.4.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
[[package]] [[package]]
name = "macro_rules_attribute" name = "macro_rules_attribute"
...@@ -1812,15 +1820,6 @@ version = "2.7.4" ...@@ -1812,15 +1820,6 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "memmap2"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "memo-map" name = "memo-map"
version = "0.3.3" version = "0.3.3"
...@@ -2284,9 +2283,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" ...@@ -2284,9 +2283,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]] [[package]]
name = "petgraph" name = "petgraph"
version = "0.7.1" version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [ dependencies = [
"fixedbitset", "fixedbitset",
"indexmap 2.7.1", "indexmap 2.7.1",
...@@ -2432,9 +2431,9 @@ dependencies = [ ...@@ -2432,9 +2431,9 @@ dependencies = [
[[package]] [[package]]
name = "prost" name = "prost"
version = "0.13.5" version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec"
dependencies = [ dependencies = [
"bytes", "bytes",
"prost-derive", "prost-derive",
...@@ -2442,12 +2441,12 @@ dependencies = [ ...@@ -2442,12 +2441,12 @@ dependencies = [
[[package]] [[package]]
name = "prost-build" name = "prost-build"
version = "0.13.5" version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b"
dependencies = [ dependencies = [
"heck", "heck",
"itertools 0.14.0", "itertools 0.13.0",
"log", "log",
"multimap", "multimap",
"once_cell", "once_cell",
...@@ -2462,12 +2461,12 @@ dependencies = [ ...@@ -2462,12 +2461,12 @@ dependencies = [
[[package]] [[package]]
name = "prost-derive" name = "prost-derive"
version = "0.13.5" version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"itertools 0.14.0", "itertools 0.13.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.98", "syn 2.0.98",
...@@ -2475,9 +2474,9 @@ dependencies = [ ...@@ -2475,9 +2474,9 @@ dependencies = [
[[package]] [[package]]
name = "prost-types" name = "prost-types"
version = "0.13.5" version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" checksum = "cc2f1e56baa61e93533aebc21af4d2134b70f66275e0fcdf3cbe43d77ff7e8fc"
dependencies = [ dependencies = [
"prost", "prost",
] ]
...@@ -2661,9 +2660,9 @@ dependencies = [ ...@@ -2661,9 +2660,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.9" version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f" checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834"
dependencies = [ dependencies = [
"bitflags 2.8.0", "bitflags 2.8.0",
] ]
...@@ -2725,14 +2724,15 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" ...@@ -2725,14 +2724,15 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.17.11" version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73" checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [ dependencies = [
"cc", "cc",
"cfg-if 1.0.0", "cfg-if 1.0.0",
"getrandom 0.2.15", "getrandom 0.2.15",
"libc", "libc",
"spin",
"untrusted", "untrusted",
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
...@@ -2767,9 +2767,9 @@ dependencies = [ ...@@ -2767,9 +2767,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.23.23" version = "0.23.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7"
dependencies = [ dependencies = [
"log", "log",
"once_cell", "once_cell",
...@@ -2920,18 +2920,18 @@ dependencies = [ ...@@ -2920,18 +2920,18 @@ dependencies = [
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.218" version = "1.0.217"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.218" version = "1.0.217"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -2940,9 +2940,9 @@ dependencies = [ ...@@ -2940,9 +2940,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.139" version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr", "memchr",
...@@ -3069,9 +3069,9 @@ dependencies = [ ...@@ -3069,9 +3069,9 @@ dependencies = [
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "1.14.0" version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]] [[package]]
name = "socket2" name = "socket2"
...@@ -3083,6 +3083,12 @@ dependencies = [ ...@@ -3083,6 +3083,12 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"
...@@ -3117,6 +3123,28 @@ version = "0.11.1" ...@@ -3117,6 +3123,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
...@@ -3183,9 +3211,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" ...@@ -3183,9 +3211,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.17.1" version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91"
dependencies = [ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"fastrand", "fastrand",
...@@ -3660,6 +3688,7 @@ dependencies = [ ...@@ -3660,6 +3688,7 @@ dependencies = [
"semver", "semver",
"serde", "serde",
"serde_json", "serde_json",
"strum",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
"tokio", "tokio",
...@@ -3764,9 +3793,9 @@ checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e" ...@@ -3764,9 +3793,9 @@ checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e"
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.18.0" version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]] [[package]]
name = "uncased" name = "uncased"
...@@ -3785,9 +3814,9 @@ checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c" ...@@ -3785,9 +3814,9 @@ checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.17" version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
[[package]] [[package]]
name = "unicode-normalization-alignments" name = "unicode-normalization-alignments"
...@@ -3878,9 +3907,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" ...@@ -3878,9 +3907,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.14.0" version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1" checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0"
dependencies = [ dependencies = [
"getrandom 0.3.1", "getrandom 0.3.1",
"serde", "serde",
...@@ -4253,9 +4282,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" ...@@ -4253,9 +4282,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]] [[package]]
name = "winnow" name = "winnow"
version = "0.7.3" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1" checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
......
...@@ -77,7 +77,7 @@ impl KvMetricsPublisher { ...@@ -77,7 +77,7 @@ impl KvMetricsPublisher {
let rs_publisher = self.inner.clone(); let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone(); let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let _ = rs_publisher rs_publisher
.create_service(rs_component) .create_service(rs_component)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -85,9 +85,9 @@ impl KvMetricsPublisher { ...@@ -85,9 +85,9 @@ impl KvMetricsPublisher {
}) })
} }
fn publish<'p>( fn publish(
&self, &self,
py: Python<'p>, _py: Python,
request_active_slots: u64, request_active_slots: u64,
request_total_slots: u64, request_total_slots: u64,
kv_active_blocks: u64, kv_active_blocks: u64,
......
...@@ -2905,7 +2905,7 @@ dependencies = [ ...@@ -2905,7 +2905,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_plain", "serde_plain",
"serde_yaml", "serde_yaml",
"strum", "strum 0.26.3",
"sysinfo", "sysinfo",
"thiserror 1.0.69", "thiserror 1.0.69",
"tokenizers", "tokenizers",
...@@ -4725,7 +4725,16 @@ version = "0.26.3" ...@@ -4725,7 +4725,16 @@ version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [ dependencies = [
"strum_macros", "strum_macros 0.26.4",
]
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros 0.27.1",
] ]
[[package]] [[package]]
...@@ -4741,6 +4750,19 @@ dependencies = [ ...@@ -4741,6 +4750,19 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]] [[package]]
name = "subtle" name = "subtle"
version = "2.6.1" version = "2.6.1"
...@@ -5438,6 +5460,7 @@ dependencies = [ ...@@ -5438,6 +5460,7 @@ dependencies = [
"sentencepiece", "sentencepiece",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.27.1",
"tempfile", "tempfile",
"thiserror 2.0.11", "thiserror 2.0.11",
"tokenizers", "tokenizers",
......
...@@ -55,6 +55,7 @@ tracing = { version = "0.1" } ...@@ -55,6 +55,7 @@ tracing = { version = "0.1" }
validator = { version = "0.20.0", features = ["derive"] } validator = { version = "0.20.0", features = ["derive"] }
uuid = { version = "1", features = ["v4", "serde"] } uuid = { version = "1", features = ["v4", "serde"] }
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] } xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
strum = { version = "0.27", features = ["derive"] }
[dependencies] [dependencies]
...@@ -77,6 +78,7 @@ tracing = { workspace = true } ...@@ -77,6 +78,7 @@ tracing = { workspace = true }
validator = { workspace = true } validator = { workspace = true }
uuid = { workspace = true } uuid = { workspace = true }
xxhash-rust = { workspace = true } xxhash-rust = { workspace = true }
strum = { workspace = true }
blake3 = "1" blake3 = "1"
......
...@@ -20,15 +20,18 @@ use tokio::sync::mpsc::Receiver; ...@@ -20,15 +20,18 @@ use tokio::sync::mpsc::Receiver;
use triton_distributed_runtime::{ use triton_distributed_runtime::{
protocols::{self, annotated::Annotated}, protocols::{self, annotated::Annotated},
raise,
transports::etcd::{KeyValue, WatchEvent}, transports::etcd::{KeyValue, WatchEvent},
DistributedRuntime, Result, DistributedRuntime, Result,
}; };
use super::ModelManager; use super::ModelManager;
use crate::model_type::ModelType;
use crate::protocols::openai::chat_completions::{ use crate::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, ChatCompletionRequest, ChatCompletionResponseDelta,
}; };
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use tracing;
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models /// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster. /// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
...@@ -40,10 +43,14 @@ pub struct ModelEntry { ...@@ -40,10 +43,14 @@ pub struct ModelEntry {
/// Component of the endpoint. /// Component of the endpoint.
pub endpoint: protocols::Endpoint, pub endpoint: protocols::Endpoint,
/// Specifies whether the model is a chat or completion model.s
pub model_type: ModelType,
} }
pub struct ModelWatchState { pub struct ModelWatchState {
pub prefix: String, pub prefix: String,
pub model_type: ModelType,
pub manager: ModelManager, pub manager: ModelManager,
pub drt: DistributedRuntime, pub drt: DistributedRuntime,
} }
...@@ -56,26 +63,19 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc ...@@ -56,26 +63,19 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc
while let Some(event) = events_rx.recv().await { while let Some(event) = events_rx.recv().await {
match event { match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await { WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
Ok(model_name) => { Ok((model_name, model_type)) => {
tracing::info!("added chat model: {}", model_name); tracing::info!("added {} model: {}", model_type, model_name);
} }
Err(e) => { Err(e) => {
tracing::error!("error adding chat model: {}", e); tracing::error!("error adding model: {}", e);
// tracing::warn!(
// "deleting offending key: {}",
// kv.key_str().unwrap_or_default()
// );
// if let Err(e) = kv_client.delete(kv.key(), None).await {
// tracing::error!("failed to delete offending key: {}", e);
// }
} }
}, },
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await { WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok(model_name) => { Ok((model_name, model_type)) => {
tracing::info!("removed chat model: {}", model_name); tracing::info!("removed {} model: {}", model_type, model_name);
} }
Err(e) => { Err(e) => {
tracing::error!("error removing chat model: {}", e); tracing::error!("error removing model: {}", e);
} }
}, },
} }
...@@ -84,33 +84,35 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc ...@@ -84,33 +84,35 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc
tracing::debug!("model watcher stopped"); tracing::debug!("model watcher stopped");
} }
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> { async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
tracing::debug!("removing model"); tracing::debug!("removing model");
let key = kv.key_str()?; let key = kv.key_str()?;
tracing::debug!("key: {}", key); tracing::debug!("key: {}", key);
let model_name = key.trim_start_matches(&state.prefix); let model_name = key.trim_start_matches(&state.prefix);
state.manager.remove_chat_completions_model(model_name)?;
Ok(model_name.to_string()) match state.model_type {
ModelType::Chat => state.manager.remove_chat_completions_model(model_name)?,
ModelType::Completion => state.manager.remove_completions_model(model_name)?,
};
Ok((model_name, state.model_type))
} }
// Handles a PUT event from etcd, this usually means adding a new model to the list of served // Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models. // models.
// //
// If this method errors, for the near term, we will delete the offending key. // If this method errors, for the near term, we will delete the offending key.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> { async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
tracing::debug!("adding model"); tracing::debug!("adding model");
let key = kv.key_str()?; let key = kv.key_str()?;
tracing::debug!("key: {}", key); tracing::debug!("key: {}", key);
//let model_name = key.trim_start_matches(&state.prefix); let model_name = key.trim_start_matches(&state.prefix);
let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?; let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;
/*
// this means there is an entry in etcd that breaks the contract that the key
// in the models path must match the model name in the entry.
if model_entry.name != model_name { if model_entry.name != model_name {
raise!( raise!(
"model name mismatch: {} != {}", "model name mismatch: {} != {}",
...@@ -118,23 +120,40 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String ...@@ -118,23 +120,40 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String
model_name model_name
); );
} }
*/ if model_entry.model_type != state.model_type {
raise!(
let client = state "model type mismatch: {} != {}",
.drt model_entry.model_type,
.namespace(model_entry.endpoint.namespace)? state.model_type
.component(model_entry.endpoint.component)? );
.endpoint(model_entry.endpoint.name) }
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
let client = Arc::new(client);
let model_name = model_entry.name.clone(); match state.model_type {
tracing::info!("New model registered: {model_name}"); ModelType::Chat => {
state let client = state
.manager .drt
.add_chat_completions_model(&model_name, client)?; .namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
state
.manager
.add_chat_completions_model(model_name, Arc::new(client))?;
}
ModelType::Completion => {
let client = state
.drt
.namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<CompletionRequest, Annotated<CompletionResponse>>()
.await?;
state
.manager
.add_completions_model(model_name, Arc::new(client))?;
}
}
Ok(model_name.to_string()) Ok((model_name, state.model_type))
} }
...@@ -99,14 +99,14 @@ impl HttpServiceConfigBuilder { ...@@ -99,14 +99,14 @@ impl HttpServiceConfigBuilder {
]; ];
if config.enable_chat_endpoints { if config.enable_chat_endpoints {
routes.push(super::openai::completions_router( routes.push(super::openai::chat_completions_router(
model_manager.state(), model_manager.state(),
None, None,
)); ));
} }
if config.enable_cmpl_endpoints { if config.enable_cmpl_endpoints {
routes.push(super::openai::chat_completions_router( routes.push(super::openai::completions_router(
model_manager.state(), model_manager.state(),
None, None,
)); ));
......
...@@ -932,8 +932,8 @@ mod tests { ...@@ -932,8 +932,8 @@ mod tests {
fn test_radix_tree() { fn test_radix_tree() {
let mut trie = RadixTree::new(); let mut trie = RadixTree::new();
let worker_1 = uuid::Uuid::new_v4(); let worker_1 = 0;
let worker_2 = uuid::Uuid::new_v4(); let worker_2 = 1;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None)); trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None));
...@@ -1135,8 +1135,8 @@ mod tests { ...@@ -1135,8 +1135,8 @@ mod tests {
fn test_remove_worker() { fn test_remove_worker() {
let mut trie = RadixTree::new(); let mut trie = RadixTree::new();
let worker_0 = uuid::Uuid::new_v4(); let worker_0 = 0;
let worker_1 = uuid::Uuid::new_v4(); let worker_1 = 1;
assert!(trie assert!(trie
.find_matches(vec![LocalBlockHash(0)], false) .find_matches(vec![LocalBlockHash(0)], false)
...@@ -1159,8 +1159,8 @@ mod tests { ...@@ -1159,8 +1159,8 @@ mod tests {
fn test_early_stopping() { fn test_early_stopping() {
let mut trie = RadixTree::new(); let mut trie = RadixTree::new();
let worker_0 = uuid::Uuid::new_v4(); let worker_0 = 0;
let worker_1 = uuid::Uuid::new_v4(); let worker_1 = 1;
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None)); trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None));
trie.apply_event(create_store_event(worker_1, 0, vec![0], None)); trie.apply_event(create_store_event(worker_1, 0, vec![0], None));
...@@ -1276,7 +1276,7 @@ mod tests { ...@@ -1276,7 +1276,7 @@ mod tests {
#[case(8)] #[case(8)]
#[tokio::test] #[tokio::test]
async fn test_apply_event(#[case] num_shards: usize) { async fn test_apply_event(#[case] num_shards: usize) {
let worker_id = uuid::Uuid::new_v4(); let worker_id = 0;
let token = CancellationToken::new(); let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards); let mut kv_indexer = make_indexer(&token, num_shards);
...@@ -1327,7 +1327,7 @@ mod tests { ...@@ -1327,7 +1327,7 @@ mod tests {
)); ));
} }
let worker_id = uuid::Uuid::new_v4(); let worker_id = 0;
let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None); let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
kv_indexer.apply_event(event).await; kv_indexer.apply_event(event).await;
...@@ -1367,7 +1367,7 @@ mod tests { ...@@ -1367,7 +1367,7 @@ mod tests {
#[test] #[test]
fn test_router_event_new() { fn test_router_event_new() {
let worker_id = uuid::Uuid::new_v4(); let worker_id = 0;
let kv_cache_event = KvCacheEvent { let kv_cache_event = KvCacheEvent {
event_id: 1, event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
......
...@@ -24,6 +24,7 @@ pub mod engines; ...@@ -24,6 +24,7 @@ pub mod engines;
pub mod http; pub mod http;
pub mod kv_router; pub mod kv_router;
pub mod model_card; pub mod model_card;
pub mod model_type;
pub mod preprocessor; pub mod preprocessor;
pub mod protocols; pub mod protocols;
pub mod tokenizers; pub mod tokenizers;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use strum::Display;
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
pub enum ModelType {
Chat,
Completion,
}
impl ModelType {
pub fn as_str(&self) -> &str {
match self {
Self::Chat => "chat",
Self::Completion => "completion",
}
}
pub fn all() -> Vec<Self> {
vec![Self::Chat, Self::Completion]
}
}
...@@ -148,11 +148,16 @@ where ...@@ -148,11 +148,16 @@ where
}) })
} }
/// String identifying namepoint/component/endpoint /// String identifying <namespace>/<component>/<endpoint>
pub fn path(&self) -> String { pub fn path(&self) -> String {
self.endpoint.path() self.endpoint.path()
} }
/// String identifying <namespace>/component/<component>/<endpoint>
pub fn etcd_path(&self) -> String {
self.endpoint.etcd_path()
}
pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> { pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
&self.watch_rx &self.watch_rx
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment