"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "7f977ea3ab5638bb0e927fbdbc21f7693a71563a"
Commit b760c569, authored by Alec, committed by GitHub
Browse files

feat: Add completion endpoint to http server and llmctl (#230)


Co-authored-by: aflowers <aflowers@nvidia.com>
parent 113f4d91
......@@ -2867,6 +2867,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]]
name = "subtle"
version = "2.6.1"
......@@ -3410,6 +3432,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"strum",
"thiserror 2.0.11",
"tokenizers",
"tokio",
......
......@@ -3016,6 +3016,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]]
name = "subtle"
version = "2.6.1"
......@@ -3582,6 +3604,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"strum",
"thiserror 2.0.11",
"tokenizers",
"tokio",
......
......@@ -16,11 +16,16 @@
use clap::Parser;
use std::sync::Arc;
use triton_distributed_llm::http::service::{
discovery::{model_watcher, ModelWatchState},
service_v2::HttpService,
use triton_distributed_llm::{
http::service::{
discovery::{model_watcher, ModelWatchState},
service_v2::HttpService,
},
model_type::ModelType,
};
use triton_distributed_runtime::{
logging, transports::etcd::PrefixWatcher, DistributedRuntime, Result, Runtime, Worker,
};
use triton_distributed_runtime::{logging, DistributedRuntime, Result, Runtime, Worker};
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
......@@ -50,10 +55,8 @@ fn main() -> Result<()> {
async fn app(runtime: Runtime) -> Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let args = Args::parse();
// create the http service and acquire the model manager
let http_service = HttpService::builder()
.port(args.port)
.host(args.host)
......@@ -71,21 +74,29 @@ async fn app(runtime: Runtime) -> Result<()> {
let component = distributed
.namespace(&args.namespace)?
.component(&args.component)?;
let etcd_root = component.etcd_path();
let etcd_path = format!("{}/models/chat/", etcd_root);
let state = Arc::new(ModelWatchState {
prefix: etcd_path.clone(),
manager,
drt: distributed.clone(),
});
// Create watchers for all model types
let mut watcher_tasks = Vec::new();
for model_type in ModelType::all() {
let etcd_path = format!("{}/models/{}/", etcd_root, model_type.as_str());
let state = Arc::new(ModelWatchState {
prefix: etcd_path.clone(),
model_type,
manager: manager.clone(),
drt: distributed.clone(),
});
let etcd_client = distributed.etcd_client();
let models_watcher = etcd_client.kv_get_and_watch_prefix(etcd_path).await?;
let etcd_client = distributed.etcd_client();
let models_watcher: PrefixWatcher = etcd_client.kv_get_and_watch_prefix(etcd_path).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(model_watcher(state, receiver));
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let watcher_task = tokio::spawn(model_watcher(state, receiver));
watcher_tasks.push(watcher_task);
}
// Run the service
http_service.run(runtime.child_token()).await
}
......@@ -16,18 +16,96 @@
use clap::{Parser, Subcommand};
use tracing as log;
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_llm::{http::service::discovery::ModelEntry, model_type::ModelType};
use triton_distributed_runtime::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime,
Result, Runtime, Worker,
};
// Macro to define model types and associated commands
//
// Each input tuple has the form `(Variant, "primary-name", ["alias", ...], "help text")`:
// - `Variant` becomes a variant of `AddCommands`, `ListCommands` and `RemoveCommands`,
//   and is also used as `ModelType::Variant`, so it must name a `ModelType` variant.
// - the primary name and the aliases are passed to `#[command(name = ..., aliases = ...)]`.
// - the help text is the doc string of the `AddCommands` variant only; the list/remove
//   doc strings are built from the primary name with `concat!`.
macro_rules! define_type_subcommands {
($(($variant:ident, $primary_name:expr, [$($alias:expr),*], $help:expr)),* $(,)?) => {
// One `add` subcommand per model type, carrying the model and endpoint names.
#[derive(Subcommand)]
enum AddCommands {
$(
#[doc = $help]
#[command(name = $primary_name, aliases = [$($alias),*])]
$variant(AddModelArgs),
)*
}
// One `list` subcommand per model type; these variants carry no arguments.
#[derive(Subcommand)]
enum ListCommands {
$(
#[doc = concat!("List ", $primary_name, " models")]
#[command(name = $primary_name, aliases = [$($alias),*])]
$variant,
)*
}
// One `remove` subcommand per model type, carrying the model name.
#[derive(Subcommand)]
enum RemoveCommands {
$(
#[doc = concat!("Remove ", $primary_name, " model")]
#[command(name = $primary_name, aliases = [$($alias),*])]
$variant(RemoveModelArgs),
)*
}
impl AddCommands {
// Consumes the parsed subcommand and returns (model type, model name, endpoint name).
fn into_parts(self) -> (ModelType, String, String) {
match self {
$(Self::$variant(args) => (ModelType::$variant, args.model_name, args.endpoint_name)),*
}
}
}
impl RemoveCommands {
// Consumes the parsed subcommand and returns (model type, model name).
fn into_parts(self) -> (ModelType, String) {
match self {
$(Self::$variant(args) => (ModelType::$variant, args.model_name)),*
}
}
}
impl ListCommands {
// Maps the selected subcommand to the `ModelType` variant of the same name.
fn model_type(&self) -> ModelType {
match self {
$(Self::$variant => ModelType::$variant),*
}
}
}
}
}
// Table of supported model types. Each entry is
// `(ModelType variant, primary subcommand name, [aliases], help text for the add subcommand)`.
define_type_subcommands!(
(
Chat,
"chat",
["chat-model", "chat-models"],
"Add a chat model"
),
(
Completion,
"completion",
["completions", "completion-model"],
"Add a completion model"
),
// Add new model types here:
);
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(
author="NVIDIA",
version="0.2.1",
about="LLMCTL - Control and manage TRD Components",
long_about = None,
disable_help_subcommand = true,
)]
struct Cli {
/// Namespace to operate in
/// Public Namespace to operate in
#[arg(short = 'n', long)]
namespace: Option<String>,
public_namespace: Option<String>,
#[command(subcommand)]
command: Commands,
......@@ -44,42 +122,41 @@ enum Commands {
#[derive(Subcommand)]
enum HttpCommands {
/// Add a chat model
/// Add models
Add {
/// Specifies we're adding a chat model
#[arg(value_name = "chat-model")]
chat_model: String,
/// Model name (e.g. foo/v1)
model_name: String,
/// Endpoint name (format: component.endpoint or namespace.component.endpoint)
endpoint_name: String,
#[command(subcommand)]
model_type: AddCommands,
},
/// List chat models
/// List models (all types if no specific type provided)
List {
/// Specifies we're listing chat models
#[arg(value_name = "chat-model", value_parser = parse_chat_model)]
chat_model: String,
#[command(subcommand)]
model_type: Option<ListCommands>,
},
/// Remove a chat model
/// Remove models
Remove {
/// Specifies we're removing a chat model
#[arg(value_name = "chat-model")]
chat_model: String,
/// Name of the model to remove
name: String,
#[command(subcommand)]
model_type: RemoveCommands,
},
}
fn parse_chat_model(s: &str) -> Result<String> {
match s {
"chat-model" | "chat-models" => Ok(s.to_string()),
_ => raise!("Expected 'chat-model' or 'chat-models'"),
}
// Arguments carried by every `AddCommands` variant: the model name followed by
// the endpoint name. Neither field declares a `short`/`long` flag.
#[derive(Parser)]
struct AddModelArgs {
/// Model name (e.g. foo/v1)
#[arg(name = "model-name")]
model_name: String,
/// Endpoint name (format: component.endpoint or namespace.component.endpoint)
#[arg(name = "endpoint-name")]
endpoint_name: String,
}
/// Common fields for removing any model type
// Carried by every `RemoveCommands` variant; only the model name is needed
// because the model type comes from the chosen subcommand.
#[derive(Parser)]
struct RemoveModelArgs {
/// Name of the model to remove
#[arg(name = "model-name")]
model_name: String,
}
fn main() -> Result<()> {
......@@ -87,7 +164,7 @@ fn main() -> Result<()> {
let cli = Cli::parse();
// Default namespace to "public" if not specified
let namespace = cli.namespace.unwrap_or_else(|| "public".to_string());
let namespace = cli.public_namespace.unwrap_or_else(|| "public".to_string());
let worker = Worker::from_settings()?;
worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await })
......@@ -100,136 +177,262 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands)
match command {
Commands::Http { command } => {
match command {
HttpCommands::Add {
chat_model: _,
model_name,
endpoint_name,
} => {
log::debug!(
"Adding model {} with endpoint {}",
HttpCommands::Add { model_type } => {
let (model_type, model_name, endpoint_name) = model_type.into_parts();
add_model(
&distributed,
namespace.to_string(),
model_type,
model_name,
endpoint_name
);
// parse endpoint
// split by '.' must have 2, can have 3 parts, any more or less is an error
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 || parts.len() > 3 {
raise!("Invalid endpoint name: {}", endpoint_name);
}
// if 3 parts, then it's namespace.component.endpoint
// if 2 parts, then it's model_name.component.endpoint
// create model entry
let endpoint = Endpoint {
namespace: if parts.len() == 3 {
parts[0].to_string()
} else {
namespace.clone()
},
component: parts[parts.len() - 2].to_string(),
name: parts[parts.len() - 1].to_string(),
};
let model = ModelEntry {
name: model_name.clone(),
endpoint,
};
// add model to etcd
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!("{}/models/chat/{}", component.etcd_path(), model_name);
let etcd_client = distributed.etcd_client();
etcd_client
.kv_create(path, serde_json::to_vec_pretty(&model)?, None)
.await?;
println!("Model {} added to namespace {}", model_name, namespace);
&endpoint_name,
)
.await?;
}
HttpCommands::List { chat_model: _ } => {
let component = distributed.namespace(&namespace)?.component("http")?;
// todo - make this part of the http discovery service object
let prefix = format!("{}/models/chat/", component.etcd_path());
// get the kvs from etcd
let etcd_client = distributed.etcd_client();
let kvs = etcd_client.kv_get_prefix(&prefix).await?;
use tabled::Tabled;
#[derive(Tabled)]
struct ModelRow {
#[tabled(rename = "MODEL NAME")]
name: String,
#[tabled(rename = "NAMESPACE")]
namespace: String,
#[tabled(rename = "COMPONENT")]
component: String,
#[tabled(rename = "ENDPOINT")]
endpoint: String,
}
// parse the keys
let mut models = Vec::new();
for kv in kvs {
match (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
(Ok(key), Ok(model)) => {
models.push(ModelRow {
name: key.trim_start_matches(&prefix).to_string(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
(Err(e), _) => {
log::debug!("Error parsing key: {}", e);
}
(_, Err(e)) => {
log::debug!("Error parsing value: {}", e);
}
HttpCommands::List { model_type } => {
match model_type {
Some(model_type) => {
list_models(
&distributed,
namespace.clone(),
Some(model_type.model_type()),
)
.await?;
}
}
if models.is_empty() {
println!("No chat models found in namespace {}", namespace);
} else {
let table = tabled::Table::new(models);
println!("Listing chat models in namespace {}", namespace);
println!("{}", table);
}
}
HttpCommands::Remove {
chat_model: _,
name,
} => {
// TODO: Implement remove logic
log::debug!("Removing model {}", name);
let component = distributed.namespace(&namespace)?.component("http")?;
// todo - make this part of the http discovery service object
let prefix = format!("{}/models/chat/{name}", component.etcd_path());
log::debug!("deleting key: {}", prefix);
// get the kvs from etcd
let mut kv_client = distributed.etcd_client().etcd_client().kv_client();
match kv_client.delete(prefix.as_bytes(), None).await {
Ok(_response) => {
println!("Model {} removed from namespace {}", name, namespace);
}
Err(e) => {
log::error!("Error removing model {}: {}", name, e);
None => {
// List all model types
list_models(&distributed, namespace.clone(), None).await?;
}
}
}
HttpCommands::Remove { model_type } => {
let (model_type, name) = model_type.into_parts();
remove_model(&distributed, namespace.to_string(), model_type, &name).await?;
}
}
}
}
Ok(())
}
// Helper functions to handle the actual operations
/// Registers a model of the given type under the `http` component in etcd.
///
/// `endpoint_name` must be `component.endpoint` or `namespace.component.endpoint`.
/// When the namespace segment is omitted, `namespace` is used for the endpoint.
/// The entry is stored at `{component etcd path}/models/{model type}/{model name}`.
/// If an entry already exists at exactly that key, nothing is written. In both
/// cases the stored entry is printed afterwards.
///
/// # Errors
/// Fails when the endpoint name has fewer than two or more than three segments,
/// when any segment is empty, or when an etcd or serialization call fails.
async fn add_model(
    distributed: &DistributedRuntime,
    namespace: String,
    model_type: ModelType,
    model_name: String,
    endpoint_name: &str,
) -> Result<()> {
    log::debug!(
        "Adding model {} with endpoint {}",
        model_name,
        endpoint_name
    );

    let parts: Vec<&str> = endpoint_name.split('.').collect();
    if parts.len() < 2 {
        raise!("Endpoint name '{}' is too short. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
    } else if parts.len() > 3 {
        raise!("Endpoint name '{}' is too long. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
    }
    // Inputs such as "a..b" or "component." split into the right number of
    // parts but would register an endpoint with an empty field.
    if parts.iter().any(|part| part.is_empty()) {
        raise!("Endpoint name '{}' contains an empty segment. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
    }

    // create model entry
    let endpoint = Endpoint {
        namespace: if parts.len() == 3 {
            parts[0].to_string()
        } else {
            println!(
                "Using the public namespace: {} for model: {}",
                namespace, model_name
            );
            namespace.clone()
        },
        component: parts[parts.len() - 2].to_string(),
        name: parts[parts.len() - 1].to_string(),
    };

    let model = ModelEntry {
        name: model_name.clone(),
        endpoint,
        model_type,
    };

    // add model to etcd
    let component = distributed.namespace(&namespace)?.component("http")?;
    let path = format!(
        "{}/models/{}/{}",
        component.etcd_path(),
        model_type.as_str(),
        model_name
    );
    let etcd_client = distributed.etcd_client();

    // Check whether the model already exists. The query is prefix-based, so
    // only an exact key match counts; otherwise an existing "foo-bar" would
    // block adding "foo".
    let kvs = etcd_client.kv_get_prefix(&path).await?;
    let already_exists = kvs
        .iter()
        .any(|kv| matches!(kv.key_str(), Ok(key) if key == path.as_str()));

    if already_exists {
        println!(
            "{} model {} already exists, please remove it before changing the endpoint.",
            model_type.as_str(),
            model_name,
        );
    } else {
        etcd_client
            .kv_create(path, serde_json::to_vec_pretty(&model)?, None)
            .await?;
        println!("Added new {} model {}", model_type.as_str(), model_name,);
    }

    // Show the entry that is now stored, whether pre-existing or just created.
    list_single_model(distributed, namespace, model_type, model_name).await?;
    Ok(())
}
/// One row of the model table printed by the listing helpers.
/// Each `rename` attribute supplies the column header for its field.
#[derive(tabled::Tabled)]
struct ModelRow {
#[tabled(rename = "MODEL TYPE")]
model_type: String,
#[tabled(rename = "MODEL NAME")]
name: String,
#[tabled(rename = "NAMESPACE")]
namespace: String,
#[tabled(rename = "COMPONENT")]
component: String,
#[tabled(rename = "ENDPOINT")]
endpoint: String,
}
async fn list_single_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
model_name: String,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_path(),
model_type.as_str(),
model_name
);
let mut models = Vec::new();
let etcd_client = distributed.etcd_client();
let kvs = etcd_client.kv_get_prefix(&path).await?;
for kv in kvs {
if let (Ok(_key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: model_type.as_str().to_string(),
name: model_name.clone(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
}
if models.is_empty() {
println!("Something went wrong, no model was found.");
} else {
let table = tabled::Table::new(models);
println!("{}", table);
}
Ok(())
}
async fn list_models(
distributed: &DistributedRuntime,
namespace: String,
model_type: Option<ModelType>,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let mut models = Vec::new();
let model_types = match model_type {
Some(mt) => vec![mt],
None => ModelType::all(),
};
for mt in model_types {
let prefix = format!("{}/models/{}/", component.etcd_path(), mt.as_str(),);
let etcd_client = distributed.etcd_client();
let kvs = etcd_client.kv_get_prefix(&prefix).await?;
for kv in kvs {
if let (Ok(key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: mt.as_str().to_string(),
name: key.trim_start_matches(&prefix).to_string(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
}
}
if models.is_empty() {
match &model_type {
Some(mt) => println!(
"No {} models found in the public namespace: {}",
mt.as_str(),
namespace
),
None => println!("No models found in the public namespace: {}", namespace),
}
} else {
let table = tabled::Table::new(models);
match &model_type {
Some(mt) => println!(
"Listing {} models in the public namespace: {}",
mt.as_str(),
namespace
),
None => println!("Listing all models in the public namespace: {}", namespace),
}
println!("{}", table);
}
Ok(())
}
async fn remove_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
name: &str,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let prefix = format!(
"{}/models/{}/{}",
component.etcd_path(),
model_type.as_str(),
name
);
log::debug!("deleting key: {}", prefix);
// get the kvs from etcd
let mut kv_client = distributed.etcd_client().etcd_client().kv_client();
match kv_client.delete(prefix.as_bytes(), None).await {
Ok(_response) => {
println!(
"{} model {} removed from the public namespace: {}",
model_type.as_str(),
name,
namespace
);
}
Err(e) => {
log::error!("Error removing model {}: {}", name, e);
}
}
Ok(())
}
......@@ -2806,7 +2806,7 @@ dependencies = [
"serde_json",
"serde_plain",
"serde_yaml",
"strum",
"strum 0.26.3",
"sysinfo",
"thiserror 1.0.69",
"tokenizers",
......@@ -4463,7 +4463,16 @@ version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros",
"strum_macros 0.26.4",
]
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros 0.27.1",
]
[[package]]
......@@ -4479,6 +4488,19 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]]
name = "subtle"
version = "2.6.1"
......@@ -5179,6 +5201,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"strum 0.27.1",
"thiserror 2.0.11",
"tokenizers",
"tokio",
......
......@@ -13,9 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use triton_distributed_llm::http::service::discovery::ModelEntry;
use triton_distributed_llm::{
backend::Backend,
http::service::discovery::ModelEntry,
model_type::ModelType,
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
......@@ -80,6 +81,7 @@ pub async fn run(
let model_registration = ModelEntry {
name: service_name.to_string(),
endpoint: endpoint.clone(),
model_type: ModelType::Chat,
};
etcd_client
.kv_create(
......
......@@ -18,6 +18,7 @@ use std::sync::Arc;
use triton_distributed_llm::{
backend::Backend,
http::service::{discovery, service_v2},
model_type::ModelType,
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
......@@ -49,6 +50,7 @@ pub async fn run(
// Listen for models registering themselves in etcd, add them to HTTP service
let state = Arc::new(discovery::ModelWatchState {
prefix: service_name.clone(),
model_type: ModelType::Chat, // Tio currently supports only chat models
manager: http_service.model_manager().clone(),
drt: distributed_runtime.clone(),
});
......
......@@ -2996,6 +2996,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.96",
]
[[package]]
name = "subtle"
version = "2.6.1"
......@@ -3539,6 +3561,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"strum",
"thiserror 2.0.11",
"tokenizers",
"tokio",
......
......@@ -44,4 +44,4 @@ tracing = "0"
libc = "0.2"
uuid = { version = "1", features = ["v4", "serde"] }
async-once-cell = "0.5.4"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
\ No newline at end of file
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
......@@ -18,8 +18,6 @@ use libc::c_char;
use once_cell::sync::OnceCell;
use std::ffi::CStr;
use std::sync::atomic::{AtomicU32, Ordering};
use tracing as log;
use uuid::Uuid;
use triton_distributed_llm::kv_router::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher,
......@@ -39,7 +37,7 @@ fn initialize_tracing() {
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
log::debug!("Tracing initialized");
tracing::debug!("Tracing initialized");
}
#[repr(u32)]
......@@ -141,7 +139,7 @@ fn triton_create_kv_publisher(
component: String,
worker_id: i64,
) -> Result<KvEventPublisher, anyhow::Error> {
log::info!("Creating KV Publisher for model: {}", component);
tracing::info!("Creating KV Publisher for model: {}", component);
match DRT
.get()
.ok_or(anyhow::Error::msg("Could not get Distributed Runtime"))
......@@ -197,7 +195,7 @@ fn kv_event_create_stored_from_parts(
})
.is_ok()
{
log::warn!(
tracing::warn!(
"Block size must be 64 tokens to be published. Block size is: {}",
num_toks
);
......
......@@ -93,9 +93,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.96"
version = "1.0.95"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4"
checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04"
[[package]]
name = "arrayref"
......@@ -379,16 +379,15 @@ checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
[[package]]
name = "blake3"
version = "1.6.0"
version = "1.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937"
checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e"
dependencies = [
"arrayref",
"arrayvec",
"cc",
"cfg-if 1.0.0",
"constant_time_eq",
"memmap2",
]
[[package]]
......@@ -451,9 +450,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.2.15"
version = "1.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af"
checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda"
dependencies = [
"jobserver",
"libc",
......@@ -505,18 +504,18 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.30"
version = "4.5.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d"
checksum = "3e77c3243bd94243c03672cb5154667347c457ca271254724f9f393aee1c05ff"
dependencies = [
"clap_builder",
]
[[package]]
name = "clap_builder"
version = "4.5.30"
version = "4.5.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c"
checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7"
dependencies = [
"anstream",
"anstyle",
......@@ -940,9 +939,9 @@ dependencies = [
[[package]]
name = "equivalent"
version = "1.0.2"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "erased-serde"
......@@ -1041,15 +1040,15 @@ dependencies = [
[[package]]
name = "fixedbitset"
version = "0.5.7"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.35"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c"
checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc"
dependencies = [
"crc32fast",
"miniz_oxide",
......@@ -1241,9 +1240,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "h2"
version = "0.4.8"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2"
checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e"
dependencies = [
"atomic-waker",
"bytes",
......@@ -1653,6 +1652,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.14.0"
......@@ -1765,9 +1773,9 @@ dependencies = [
[[package]]
name = "log"
version = "0.4.26"
version = "0.4.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
[[package]]
name = "macro_rules_attribute"
......@@ -1812,15 +1820,6 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "memmap2"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
dependencies = [
"libc",
]
[[package]]
name = "memo-map"
version = "0.3.3"
......@@ -2284,9 +2283,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "petgraph"
version = "0.7.1"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772"
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
dependencies = [
"fixedbitset",
"indexmap 2.7.1",
......@@ -2432,9 +2431,9 @@ dependencies = [
[[package]]
name = "prost"
version = "0.13.5"
version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5"
checksum = "2c0fef6c4230e4ccf618a35c59d7ede15dea37de8427500f50aff708806e42ec"
dependencies = [
"bytes",
"prost-derive",
......@@ -2442,12 +2441,12 @@ dependencies = [
[[package]]
name = "prost-build"
version = "0.13.5"
version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b"
dependencies = [
"heck",
"itertools 0.14.0",
"itertools 0.13.0",
"log",
"multimap",
"once_cell",
......@@ -2462,12 +2461,12 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.13.5"
version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3"
dependencies = [
"anyhow",
"itertools 0.14.0",
"itertools 0.13.0",
"proc-macro2",
"quote",
"syn 2.0.98",
......@@ -2475,9 +2474,9 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.13.5"
version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16"
checksum = "cc2f1e56baa61e93533aebc21af4d2134b70f66275e0fcdf3cbe43d77ff7e8fc"
dependencies = [
"prost",
]
......@@ -2661,9 +2660,9 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.9"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b568323e98e49e2a0899dcee453dd679fae22d69adf9b11dd508d1549b7e2f"
checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834"
dependencies = [
"bitflags 2.8.0",
]
......@@ -2725,14 +2724,15 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "ring"
version = "0.17.11"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if 1.0.0",
"getrandom 0.2.15",
"libc",
"spin",
"untrusted",
"windows-sys 0.52.0",
]
......@@ -2767,9 +2767,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.23"
version = "0.23.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7"
dependencies = [
"log",
"once_cell",
......@@ -2920,18 +2920,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.218"
version = "1.0.217"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60"
checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.218"
version = "1.0.217"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b"
checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
dependencies = [
"proc-macro2",
"quote",
......@@ -2940,9 +2940,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.139"
version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949"
dependencies = [
"itoa",
"memchr",
......@@ -3069,9 +3069,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.14.0"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "socket2"
......@@ -3083,6 +3083,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "spki"
version = "0.7.3"
......@@ -3117,6 +3123,28 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]]
name = "subtle"
version = "2.6.1"
......@@ -3183,9 +3211,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
version = "3.17.1"
version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230"
checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91"
dependencies = [
"cfg-if 1.0.0",
"fastrand",
......@@ -3660,6 +3688,7 @@ dependencies = [
"semver",
"serde",
"serde_json",
"strum",
"thiserror 2.0.11",
"tokenizers",
"tokio",
......@@ -3764,9 +3793,9 @@ checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e"
[[package]]
name = "typenum"
version = "1.18.0"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "uncased"
......@@ -3785,9 +3814,9 @@ checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c"
[[package]]
name = "unicode-ident"
version = "1.0.17"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe"
checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
[[package]]
name = "unicode-normalization-alignments"
......@@ -3878,9 +3907,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.14.0"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1"
checksum = "ced87ca4be083373936a67f8de945faa23b6b42384bd5b64434850802c6dccd0"
dependencies = [
"getrandom 0.3.1",
"serde",
......@@ -4253,9 +4282,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.7.3"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1"
checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603"
dependencies = [
"memchr",
]
......
......@@ -77,7 +77,7 @@ impl KvMetricsPublisher {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let _ = rs_publisher
rs_publisher
.create_service(rs_component)
.await
.map_err(to_pyerr)?;
......@@ -85,9 +85,9 @@ impl KvMetricsPublisher {
})
}
fn publish<'p>(
fn publish(
&self,
py: Python<'p>,
_py: Python,
request_active_slots: u64,
request_total_slots: u64,
kv_active_blocks: u64,
......
......@@ -2905,7 +2905,7 @@ dependencies = [
"serde_json",
"serde_plain",
"serde_yaml",
"strum",
"strum 0.26.3",
"sysinfo",
"thiserror 1.0.69",
"tokenizers",
......@@ -4725,7 +4725,16 @@ version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros",
"strum_macros 0.26.4",
]
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros 0.27.1",
]
[[package]]
......@@ -4741,6 +4750,19 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.98",
]
[[package]]
name = "subtle"
version = "2.6.1"
......@@ -5438,6 +5460,7 @@ dependencies = [
"sentencepiece",
"serde",
"serde_json",
"strum 0.27.1",
"tempfile",
"thiserror 2.0.11",
"tokenizers",
......
......@@ -55,6 +55,7 @@ tracing = { version = "0.1" }
validator = { version = "0.20.0", features = ["derive"] }
uuid = { version = "1", features = ["v4", "serde"] }
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
strum = { version = "0.27", features = ["derive"] }
[dependencies]
......@@ -77,6 +78,7 @@ tracing = { workspace = true }
validator = { workspace = true }
uuid = { workspace = true }
xxhash-rust = { workspace = true }
strum = { workspace = true }
blake3 = "1"
......
......@@ -20,15 +20,18 @@ use tokio::sync::mpsc::Receiver;
use triton_distributed_runtime::{
protocols::{self, annotated::Annotated},
raise,
transports::etcd::{KeyValue, WatchEvent},
DistributedRuntime, Result,
};
use super::ModelManager;
use crate::model_type::ModelType;
use crate::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
};
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use tracing;
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
......@@ -40,10 +43,14 @@ pub struct ModelEntry {
/// Component of the endpoint.
pub endpoint: protocols::Endpoint,
/// Specifies whether the model is a chat or completion model.
pub model_type: ModelType,
}
pub struct ModelWatchState {
pub prefix: String,
pub model_type: ModelType,
pub manager: ModelManager,
pub drt: DistributedRuntime,
}
......@@ -56,26 +63,19 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
Ok(model_name) => {
tracing::info!("added chat model: {}", model_name);
Ok((model_name, model_type)) => {
tracing::info!("added {} model: {}", model_type, model_name);
}
Err(e) => {
tracing::error!("error adding chat model: {}", e);
// tracing::warn!(
// "deleting offending key: {}",
// kv.key_str().unwrap_or_default()
// );
// if let Err(e) = kv_client.delete(kv.key(), None).await {
// tracing::error!("failed to delete offending key: {}", e);
// }
tracing::error!("error adding model: {}", e);
}
},
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok(model_name) => {
tracing::info!("removed chat model: {}", model_name);
Ok((model_name, model_type)) => {
tracing::info!("removed {} model: {}", model_type, model_name);
}
Err(e) => {
tracing::error!("error removing chat model: {}", e);
tracing::error!("error removing model: {}", e);
}
},
}
......@@ -84,33 +84,35 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc
tracing::debug!("model watcher stopped");
}
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> {
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
tracing::debug!("removing model");
let key = kv.key_str()?;
tracing::debug!("key: {}", key);
let model_name = key.trim_start_matches(&state.prefix);
state.manager.remove_chat_completions_model(model_name)?;
Ok(model_name.to_string())
match state.model_type {
ModelType::Chat => state.manager.remove_chat_completions_model(model_name)?,
ModelType::Completion => state.manager.remove_completions_model(model_name)?,
};
Ok((model_name, state.model_type))
}
// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
//
// If this method errors, for the near term, we will delete the offending key.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> {
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
tracing::debug!("adding model");
let key = kv.key_str()?;
tracing::debug!("key: {}", key);
//let model_name = key.trim_start_matches(&state.prefix);
let model_name = key.trim_start_matches(&state.prefix);
let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;
/*
// this means there is an entry in etcd that breaks the contract that the key
// in the models path must match the model name in the entry.
if model_entry.name != model_name {
raise!(
"model name mismatch: {} != {}",
......@@ -118,23 +120,40 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String
model_name
);
}
*/
let client = state
.drt
.namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
let client = Arc::new(client);
if model_entry.model_type != state.model_type {
raise!(
"model type mismatch: {} != {}",
model_entry.model_type,
state.model_type
);
}
let model_name = model_entry.name.clone();
tracing::info!("New model registered: {model_name}");
state
.manager
.add_chat_completions_model(&model_name, client)?;
match state.model_type {
ModelType::Chat => {
let client = state
.drt
.namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
state
.manager
.add_chat_completions_model(model_name, Arc::new(client))?;
}
ModelType::Completion => {
let client = state
.drt
.namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<CompletionRequest, Annotated<CompletionResponse>>()
.await?;
state
.manager
.add_completions_model(model_name, Arc::new(client))?;
}
}
Ok(model_name.to_string())
Ok((model_name, state.model_type))
}
......@@ -99,14 +99,14 @@ impl HttpServiceConfigBuilder {
];
if config.enable_chat_endpoints {
routes.push(super::openai::completions_router(
routes.push(super::openai::chat_completions_router(
model_manager.state(),
None,
));
}
if config.enable_cmpl_endpoints {
routes.push(super::openai::chat_completions_router(
routes.push(super::openai::completions_router(
model_manager.state(),
None,
));
......
......@@ -932,8 +932,8 @@ mod tests {
fn test_radix_tree() {
let mut trie = RadixTree::new();
let worker_1 = uuid::Uuid::new_v4();
let worker_2 = uuid::Uuid::new_v4();
let worker_1 = 0;
let worker_2 = 1;
trie.apply_event(create_store_event(worker_1, 1, vec![1, 2, 3], None));
......@@ -1135,8 +1135,8 @@ mod tests {
fn test_remove_worker() {
let mut trie = RadixTree::new();
let worker_0 = uuid::Uuid::new_v4();
let worker_1 = uuid::Uuid::new_v4();
let worker_0 = 0;
let worker_1 = 1;
assert!(trie
.find_matches(vec![LocalBlockHash(0)], false)
......@@ -1159,8 +1159,8 @@ mod tests {
fn test_early_stopping() {
let mut trie = RadixTree::new();
let worker_0 = uuid::Uuid::new_v4();
let worker_1 = uuid::Uuid::new_v4();
let worker_0 = 0;
let worker_1 = 1;
trie.apply_event(create_store_event(worker_0, 0, vec![0, 1, 2], None));
trie.apply_event(create_store_event(worker_1, 0, vec![0], None));
......@@ -1276,7 +1276,7 @@ mod tests {
#[case(8)]
#[tokio::test]
async fn test_apply_event(#[case] num_shards: usize) {
let worker_id = uuid::Uuid::new_v4();
let worker_id = 0;
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards);
......@@ -1327,7 +1327,7 @@ mod tests {
));
}
let worker_id = uuid::Uuid::new_v4();
let worker_id = 0;
let event = create_store_event(worker_id, 0, vec![1, 2, 3, 4], None);
kv_indexer.apply_event(event).await;
......@@ -1367,7 +1367,7 @@ mod tests {
#[test]
fn test_router_event_new() {
let worker_id = uuid::Uuid::new_v4();
let worker_id = 0;
let kv_cache_event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
......
......@@ -24,6 +24,7 @@ pub mod engines;
pub mod http;
pub mod kv_router;
pub mod model_card;
pub mod model_type;
pub mod preprocessor;
pub mod protocols;
pub mod tokenizers;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use strum::Display;
/// Kind of model being served: either a chat model or a completion model.
///
/// The type is `Copy`, so it can be passed and returned by value.
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
pub enum ModelType {
/// A chat model.
Chat,
/// A completion model.
Completion,
}
impl ModelType {
    /// Returns the lowercase name of this model type: `"chat"` or `"completion"`.
    ///
    /// Both arms yield string literals, so the result is `&'static str` and is not tied to the
    /// lifetime of the `ModelType` value it was obtained from.
    // NOTE(review): the derived `strum::Display` presumably renders the variant identifier
    // (`Chat` / `Completion`) rather than this lowercase form — confirm before treating the
    // two as interchangeable.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Chat => "chat",
            Self::Completion => "completion",
        }
    }

    /// Returns every variant in declaration order: `Chat`, then `Completion`.
    pub fn all() -> Vec<Self> {
        vec![Self::Chat, Self::Completion]
    }
}
......@@ -148,11 +148,16 @@ where
})
}
/// String identifying namepoint/component/endpoint
/// String identifying <namespace>/<component>/<endpoint>
pub fn path(&self) -> String {
    // Pure delegation: the wrapped endpoint builds the path string.
    let endpoint = &self.endpoint;
    endpoint.path()
}
/// Returns the etcd path of the wrapped endpoint, in the form
/// `<namespace>/component/<component>/<endpoint>`.
pub fn etcd_path(&self) -> String {
    // Pure delegation: the wrapped endpoint builds the etcd path string.
    let endpoint = &self.endpoint;
    endpoint.etcd_path()
}
pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
    // Hand out a shared borrow of the stored watch receiver; nothing is cloned.
    let receiver = &self.watch_rx;
    receiver
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment