Commit 32a748e4 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(tio): Distributed inference! (#235)

Add support in tio for distributed components and discovery.

Node 1:
```
tio in=http out=tdr://ns/backend/mistralrs
```

Node 2:
```
tio in=tdr://ns/backend/mistralrs out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
```

This will use etcd to auto-discover the model and NATS to talk to it. You can run multiple workers on the same endpoint and it will pick one at random each time.

The `ns/backend/mistralrs` path is purely symbolic; pick anything, as long as it has three parts and matches the other node.
parent 3b7a462d
...@@ -44,5 +44,21 @@ Send a request: ...@@ -44,5 +44,21 @@ Send a request:
curl -d '{"model": "Llama-3.2-1B-Instruct-Q4_K_M", "max_tokens": 2049, "messages":[{"role":"user", "content": "What is the capital of South Africa?" }]}' -H 'Content-Type: application/json' http://localhost:8080/v1/chat/completions curl -d '{"model": "Llama-3.2-1B-Instruct-Q4_K_M", "max_tokens": 2049, "messages":[{"role":"user", "content": "What is the capital of South Africa?" }]}' -H 'Content-Type: application/json' http://localhost:8080/v1/chat/completions
``` ```
*Multi-node*
Node 1:
```
tio in=http out=tdr://ns/backend/mistralrs
```
Node 2:
```
tio in=tdr://ns/backend/mistralrs out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
```
This will use etcd to auto-discover the model and NATS to talk to it. You can run multiple workers on the same endpoint and it will pick one at random each time.
The `ns/backend/mistralrs` path is purely symbolic; pick anything, as long as it has three parts and matches the other node.
Run `tio --help` for more options. Run `tio --help` for more options.
...@@ -13,5 +13,6 @@ ...@@ -13,5 +13,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod endpoint;
pub mod http; pub mod http;
pub mod text; pub mod text;
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use triton_distributed::{
pipeline::network::Ingress, protocols::Endpoint, DistributedRuntime, Runtime,
};
use triton_llm::http::service::discovery::ModelEntry;
use crate::{EngineConfig, ENDPOINT_SCHEME};
pub async fn run(
runtime: Runtime,
path: String,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
// This will attempt to connect to NATS and etcd
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
match engine_config {
EngineConfig::StaticFull {
service_name,
engine,
} => {
let cancel_token = runtime.primary_token().clone();
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
anyhow::bail!("An endpoint URL must have format {ENDPOINT_SCHEME}namespace/component/endpoint");
}
// Register with etcd
let endpoint = Endpoint {
namespace: elements[0].to_string(),
component: elements[1].to_string(),
name: elements[2].to_string(),
};
let model_registration = ModelEntry {
name: service_name.to_string(),
endpoint,
};
let etcd_client = distributed.etcd_client();
etcd_client
.kv_create(
path.clone(),
serde_json::to_vec_pretty(&model_registration)?,
None,
)
.await?;
// Start the model
let ingress = Ingress::for_engine(engine)?;
let rt_fut = distributed
.namespace(elements[0])?
.component(elements[1])?
.service_builder()
.create()
.await?
.endpoint(elements[2])
.endpoint_builder()
.handler(ingress)
.start();
tokio::select! {
_ = rt_fut => {
tracing::debug!("Endpoint ingress ended");
}
_ = cancel_token.cancelled() => {
}
}
Ok(())
}
EngineConfig::Dynamic(_) => {
anyhow::bail!("Cannot use endpoint for both in and out");
}
}
}
...@@ -13,32 +13,48 @@ ...@@ -13,32 +13,48 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use triton_distributed::runtime::CancellationToken; use std::sync::Arc;
use triton_llm::http::service::service_v2;
use triton_distributed::{DistributedRuntime, Runtime};
use triton_llm::http::service::{discovery, service_v2};
use crate::EngineConfig; use crate::EngineConfig;
/// Build and run an HTTP service /// Build and run an HTTP service
pub async fn run( pub async fn run(
cancel_token: CancellationToken, runtime: Runtime,
http_port: u16, http_port: u16,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder()
.port(http_port)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.build()?;
match engine_config { match engine_config {
EngineConfig::Dynamic(client) => {
let service_name = client.path();
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
// Listen for models registering themselves in etcd, add them to HTTP service
let state = Arc::new(discovery::ModelWatchState {
prefix: service_name.clone(),
manager: http_service.model_manager().clone(),
drt: distributed_runtime.clone(),
});
let etcd_client = distributed_runtime.etcd_client();
let models_watcher = etcd_client.kv_get_and_watch_prefix(service_name).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(discovery::model_watcher(state, receiver));
}
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name, service_name,
engine, engine,
.. ..
} => { } => {
let http_service = service_v2::HttpService::builder()
.port(http_port)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.build()?;
http_service http_service
.model_manager() .model_manager()
.add_chat_completions_model(&service_name, engine)?; .add_chat_completions_model(&service_name, engine)?;
http_service.run(cancel_token).await
} }
} }
http_service.run(runtime.primary_token()).await
} }
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
// limitations under the License. // limitations under the License.
use futures::StreamExt; use futures::StreamExt;
use std::io::{ErrorKind, Read, Write}; use std::{
io::{ErrorKind, Read, Write},
sync::Arc,
};
use triton_distributed::{pipeline::Context, runtime::CancellationToken}; use triton_distributed::{pipeline::Context, runtime::CancellationToken};
use triton_llm::{ use triton_llm::{
protocols::openai::chat_completions::MessageRole, protocols::openai::chat_completions::MessageRole,
...@@ -40,6 +43,13 @@ pub async fn run( ...@@ -40,6 +43,13 @@ pub async fn run(
OpenAIChatCompletionsStreamingEngine, OpenAIChatCompletionsStreamingEngine,
bool, bool,
) = match engine_config { ) = match engine_config {
EngineConfig::Dynamic(client) => {
// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
let service_name = client.path();
tracing::info!("Model: {service_name}");
(service_name, Arc::new(client), false)
}
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name, service_name,
engine, engine,
...@@ -48,8 +58,6 @@ pub async fn run( ...@@ -48,8 +58,6 @@ pub async fn run(
(service_name, engine, false) (service_name, engine, false)
} }
}; };
// TODO: Acquire an etcd lease, we are running
main_loop(cancel_token, &service_name, engine, inspect_template).await main_loop(cancel_token, &service_name, engine, inspect_template).await
} }
......
...@@ -15,14 +15,24 @@ ...@@ -15,14 +15,24 @@
use std::path::PathBuf; use std::path::PathBuf;
use triton_distributed::runtime::CancellationToken; use triton_distributed::{component::Client, DistributedRuntime};
use triton_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine; use triton_llm::types::{
openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta, OpenAIChatCompletionsStreamingEngine,
},
Annotated,
};
mod input; mod input;
mod opt; mod opt;
mod output; mod output;
pub use opt::{Input, Output}; pub use opt::{Input, Output};
/// How we identify a namespace/component/endpoint URL.
/// Technically the '://' is not part of the scheme but it eliminates several string
/// concatenations.
const ENDPOINT_SCHEME: &str = "tdr://";
/// Required options depend on the in and out choices /// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)] #[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
...@@ -46,6 +56,10 @@ pub struct Flags { ...@@ -46,6 +56,10 @@ pub struct Flags {
} }
pub enum EngineConfig { pub enum EngineConfig {
/// A remote networked engine we don't know about yet
/// We don't have the pre-processor yet so this is only text requests. Type will change later.
Dynamic(Client<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>),
/// A Full service engine does its own tokenization and prompt formatting. /// A Full service engine does its own tokenization and prompt formatting.
StaticFull { StaticFull {
service_name: String, service_name: String,
...@@ -54,22 +68,25 @@ pub enum EngineConfig { ...@@ -54,22 +68,25 @@ pub enum EngineConfig {
} }
pub async fn run( pub async fn run(
runtime: triton_distributed::Runtime,
in_opt: Input, in_opt: Input,
out_opt: Output, out_opt: Output,
flags: Flags, flags: Flags,
cancel_token: CancellationToken,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
// Turn relative paths into absolute paths // Turn relative paths into absolute paths
let model_path = flags let model_path = flags
.model_path_pos .model_path_pos
.or(flags.model_path_flag) .or(flags.model_path_flag)
.and_then(|p| p.canonicalize().ok()); .and_then(|p| p.canonicalize().ok());
// Serve the model under the name provided, or the name of the GGUF file. // Serve the model under the name provided, or the name of the GGUF file.
let model_name = flags.model_name.or_else(|| let model_name = flags.model_name.or_else(|| {
// "stem" means the filename without the extension. model_path
model_path.as_ref() .as_ref()
.and_then(|p| p.file_stem()) .and_then(|p| p.iter().last())
.map(|n| n.to_string_lossy().into_owned())); .map(|n| n.to_string_lossy().into_owned())
});
// Create the engine matching `out` // Create the engine matching `out`
let engine_config = match out_opt { let engine_config = match out_opt {
...@@ -84,6 +101,33 @@ pub async fn run( ...@@ -84,6 +101,33 @@ pub async fn run(
engine: output::echo_full::make_engine_full(), engine: output::echo_full::make_engine_full(),
} }
} }
Output::Endpoint(path) => {
let elements: Vec<&str> = path.split('/').collect();
if elements.len() != 3 {
anyhow::bail!("An endpoint URL must have format {ENDPOINT_SCHEME}namespace/component/endpoint");
}
// This will attempt to connect to NATS and etcd
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let client = distributed_runtime
.namespace(elements[0])?
.component(elements[1])?
.endpoint(elements[2])
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
tracing::info!("Waiting for remote {}...", client.path());
tokio::select! {
_ = cancel_token.cancelled() => {
return Ok(());
}
r = client.wait_for_endpoints() => {
r?;
}
}
EngineConfig::Dynamic(client)
}
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
Output::MistralRs => { Output::MistralRs => {
let Some(model_path) = model_path else { let Some(model_path) = model_path else {
...@@ -101,11 +145,14 @@ pub async fn run( ...@@ -101,11 +145,14 @@ pub async fn run(
match in_opt { match in_opt {
Input::Http => { Input::Http => {
crate::input::http::run(cancel_token.clone(), flags.http_port, engine_config).await?; crate::input::http::run(runtime.clone(), flags.http_port, engine_config).await?;
} }
Input::Text => { Input::Text => {
crate::input::text::run(cancel_token.clone(), engine_config).await?; crate::input::text::run(cancel_token.clone(), engine_config).await?;
} }
Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
}
} }
Ok(()) Ok(())
......
...@@ -103,11 +103,5 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()> ...@@ -103,11 +103,5 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()>
.chain(env::args().skip(non_flag_params)), .chain(env::args().skip(non_flag_params)),
)?; )?;
// etcd and nats addresses, from env vars ETCD_ENDPOINTS and NATS_SERVER with localhost tio::run(runtime, in_opt, out_opt, flags).await
// defaults
//let dt_config = triton_distributed::distributed::DistributedConfig::from_settings();
// Wraps the Runtime (which wraps two tokio runtimes) and adds etcd and nats clients
//let d_runtime = triton_distributed::DistributedRuntime::new(runtime, dt_config).await?;
tio::run(in_opt, out_opt, flags, runtime.primary_token()).await
} }
...@@ -15,12 +15,17 @@ ...@@ -15,12 +15,17 @@
use std::fmt; use std::fmt;
use crate::ENDPOINT_SCHEME;
pub enum Input { pub enum Input {
/// Run an OpenAI compatible HTTP server /// Run an OpenAI compatible HTTP server
Http, Http,
/// Read prompt from stdin /// Read prompt from stdin
Text, Text,
/// Pull requests from a namespace/component/endpoint path.
Endpoint(String),
} }
impl TryFrom<&str> for Input { impl TryFrom<&str> for Input {
...@@ -30,6 +35,10 @@ impl TryFrom<&str> for Input { ...@@ -30,6 +35,10 @@ impl TryFrom<&str> for Input {
match s { match s {
"http" => Ok(Input::Http), "http" => Ok(Input::Http),
"text" => Ok(Input::Text), "text" => Ok(Input::Text),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Input::Endpoint(path.to_string()))
}
e => Err(anyhow::anyhow!("Invalid in= option '{e}'")), e => Err(anyhow::anyhow!("Invalid in= option '{e}'")),
} }
} }
...@@ -40,6 +49,7 @@ impl fmt::Display for Input { ...@@ -40,6 +49,7 @@ impl fmt::Display for Input {
let s = match self { let s = match self {
Input::Http => "http", Input::Http => "http",
Input::Text => "text", Input::Text => "text",
Input::Endpoint(path) => path,
}; };
write!(f, "{s}") write!(f, "{s}")
} }
...@@ -49,6 +59,9 @@ pub enum Output { ...@@ -49,6 +59,9 @@ pub enum Output {
/// Accept un-preprocessed requests, echo the prompt back as the response /// Accept un-preprocessed requests, echo the prompt back as the response
EchoFull, EchoFull,
/// Publish requests to a namespace/component/endpoint path.
Endpoint(String),
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
/// Run inference on a model in a GGUF file using mistralrs w/ candle /// Run inference on a model in a GGUF file using mistralrs w/ candle
MistralRs, MistralRs,
...@@ -63,6 +76,12 @@ impl TryFrom<&str> for Output { ...@@ -63,6 +76,12 @@ impl TryFrom<&str> for Output {
"mistralrs" => Ok(Output::MistralRs), "mistralrs" => Ok(Output::MistralRs),
"echo_full" => Ok(Output::EchoFull), "echo_full" => Ok(Output::EchoFull),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Output::Endpoint(path.to_string()))
}
e => Err(anyhow::anyhow!("Invalid out= option '{e}'")), e => Err(anyhow::anyhow!("Invalid out= option '{e}'")),
} }
} }
...@@ -75,6 +94,8 @@ impl fmt::Display for Output { ...@@ -75,6 +94,8 @@ impl fmt::Display for Output {
Output::MistralRs => "mistralrs", Output::MistralRs => "mistralrs",
Output::EchoFull => "echo_full", Output::EchoFull => "echo_full",
Output::Endpoint(path) => path,
}; };
write!(f, "{s}") write!(f, "{s}")
} }
......
...@@ -76,4 +76,4 @@ insta = { version = "1.41", features = ["glob", "json", "redactions"]} ...@@ -76,4 +76,4 @@ insta = { version = "1.41", features = ["glob", "json", "redactions"]}
proptest = "1.5.0" proptest = "1.5.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
rstest = "0.18.2" rstest = "0.18.2"
tempfile = "3.17.1" tempfile = "3.17.1"
\ No newline at end of file
...@@ -17,11 +17,9 @@ use std::sync::Arc; ...@@ -17,11 +17,9 @@ use std::sync::Arc;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Receiver;
use tracing as log;
use triton_distributed::{ use triton_distributed::{
protocols::{self, annotated::Annotated}, protocols::{self, annotated::Annotated},
raise,
transports::etcd::{KeyValue, WatchEvent}, transports::etcd::{KeyValue, WatchEvent},
DistributedRuntime, Result, DistributedRuntime, Result,
}; };
...@@ -51,7 +49,7 @@ pub struct ModelWatchState { ...@@ -51,7 +49,7 @@ pub struct ModelWatchState {
} }
pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<WatchEvent>) { pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<WatchEvent>) {
log::debug!("model watcher started"); tracing::debug!("model watcher started");
let mut events_rx = events_rx; let mut events_rx = events_rx;
...@@ -59,38 +57,38 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc ...@@ -59,38 +57,38 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc
match event { match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await { WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
Ok(model_name) => { Ok(model_name) => {
log::info!("added chat model: {}", model_name); tracing::info!("added chat model: {}", model_name);
} }
Err(e) => { Err(e) => {
log::error!("error adding chat model: {}", e); tracing::error!("error adding chat model: {}", e);
// log::warn!( // tracing::warn!(
// "deleting offending key: {}", // "deleting offending key: {}",
// kv.key_str().unwrap_or_default() // kv.key_str().unwrap_or_default()
// ); // );
// if let Err(e) = kv_client.delete(kv.key(), None).await { // if let Err(e) = kv_client.delete(kv.key(), None).await {
// log::error!("failed to delete offending key: {}", e); // tracing::error!("failed to delete offending key: {}", e);
// } // }
} }
}, },
WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await { WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
Ok(model_name) => { Ok(model_name) => {
log::info!("removed chat model: {}", model_name); tracing::info!("removed chat model: {}", model_name);
} }
Err(e) => { Err(e) => {
log::error!("error removing chat model: {}", e); tracing::error!("error removing chat model: {}", e);
} }
}, },
} }
} }
log::debug!("model watcher stopped"); tracing::debug!("model watcher stopped");
} }
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> { async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> {
log::debug!("removing model"); tracing::debug!("removing model");
let key = kv.key_str()?; let key = kv.key_str()?;
log::debug!("key: {}", key); tracing::debug!("key: {}", key);
let model_name = key.trim_start_matches(&state.prefix); let model_name = key.trim_start_matches(&state.prefix);
state.manager.remove_chat_completions_model(model_name)?; state.manager.remove_chat_completions_model(model_name)?;
...@@ -102,14 +100,15 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<Str ...@@ -102,14 +100,15 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<Str
// //
// If this method errors, for the near term, we will delete the offending key. // If this method errors, for the near term, we will delete the offending key.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> { async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String> {
log::debug!("adding model"); tracing::debug!("adding model");
let key = kv.key_str()?; let key = kv.key_str()?;
log::debug!("key: {}", key); tracing::debug!("key: {}", key);
let model_name = key.trim_start_matches(&state.prefix); //let model_name = key.trim_start_matches(&state.prefix);
let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?; let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;
/*
// this means there is an entry in etcd that breaks the contract that the key // this means there is an entry in etcd that breaks the contract that the key
// in the models path must match the model name in the entry. // in the models path must match the model name in the entry.
if model_entry.name != model_name { if model_entry.name != model_name {
...@@ -119,6 +118,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String ...@@ -119,6 +118,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String
model_name model_name
); );
} }
*/
let client = state let client = state
.drt .drt
...@@ -130,9 +130,11 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String ...@@ -130,9 +130,11 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<String
let client = Arc::new(client); let client = Arc::new(client);
let model_name = model_entry.name.clone();
tracing::info!("New model registered: {model_name}");
state state
.manager .manager
.add_chat_completions_model(model_name, client)?; .add_chat_completions_model(&model_name, client)?;
Ok(model_name.to_string()) Ok(model_name.to_string())
} }
...@@ -18,10 +18,10 @@ ...@@ -18,10 +18,10 @@
//! The `triton-llm` crate is a Rust library that provides a set of traits and types for building //! The `triton-llm` crate is a Rust library that provides a set of traits and types for building
//! distributed LLM inference solutions. //! distributed LLM inference solutions.
pub mod common;
pub mod engines; pub mod engines;
pub mod http; pub mod http;
pub mod kv_router; pub mod kv_router;
pub mod model_card;
pub mod protocols; pub mod protocols;
pub mod types; pub mod types;
pub mod model_card;
pub mod common;
\ No newline at end of file
...@@ -14,4 +14,4 @@ ...@@ -14,4 +14,4 @@
// limitations under the License. // limitations under the License.
pub mod create; pub mod create;
pub mod model; pub mod model;
\ No newline at end of file
...@@ -15,15 +15,14 @@ ...@@ -15,15 +15,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::path::Path;
use std::fs;
use crate::model_card::model::ModelDeploymentCard; use crate::model_card::model::ModelDeploymentCard;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::fs;
use std::path::Path;
use crate::model_card::model::{ModelInfoType, TokenizerKind, PromptFormatterArtifact, File}; use crate::model_card::model::{File, ModelInfoType, PromptFormatterArtifact, TokenizerKind};
impl ModelDeploymentCard { impl ModelDeploymentCard {
/// Creates a ModelDeploymentCard from a local directory path. /// Creates a ModelDeploymentCard from a local directory path.
/// ///
/// Currently HuggingFace format is supported and following files are expected: /// Currently HuggingFace format is supported and following files are expected:
...@@ -57,7 +56,9 @@ impl ModelDeploymentCard { ...@@ -57,7 +56,9 @@ impl ModelDeploymentCard {
/// TODO: This will be implemented after nova-hub is integrated with the model-card /// TODO: This will be implemented after nova-hub is integrated with the model-card
/// TODO: Attempt to auto-detect model type and construct an MDC from a NGC repo /// TODO: Attempt to auto-detect model type and construct an MDC from a NGC repo
pub async fn from_ngc_repo(_: &str) -> anyhow::Result<Self> { pub async fn from_ngc_repo(_: &str) -> anyhow::Result<Self> {
Err(anyhow::anyhow!("ModelDeploymentCard::from_ngc_repo is not implemented")) Err(anyhow::anyhow!(
"ModelDeploymentCard::from_ngc_repo is not implemented"
))
} }
pub async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> { pub async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> {
...@@ -130,11 +131,12 @@ async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<File> { ...@@ -130,11 +131,12 @@ async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<File> {
} }
async fn check_for_files(repo_id: &str, files: Vec<String>) -> Result<HashMap<String, File>> { async fn check_for_files(repo_id: &str, files: Vec<String>) -> Result<HashMap<String, File>> {
let dir_entries = fs::read_dir(repo_id) let dir_entries =
.with_context(|| format!("Failed to read directory: {}", repo_id))?; fs::read_dir(repo_id).with_context(|| format!("Failed to read directory: {}", repo_id))?;
let mut found_files = HashMap::new(); let mut found_files = HashMap::new();
for entry in dir_entries { for entry in dir_entries {
let entry = entry.with_context(|| format!("Failed to read directory entry in {}", repo_id))?; let entry =
entry.with_context(|| format!("Failed to read directory entry in {}", repo_id))?;
let path = entry.path(); let path = entry.path();
let file_name = path let file_name = path
.file_name() .file_name()
...@@ -162,11 +164,17 @@ async fn check_for_files(repo_id: &str, files: Vec<String>) -> Result<HashMap<St ...@@ -162,11 +164,17 @@ async fn check_for_files(repo_id: &str, files: Vec<String>) -> Result<HashMap<St
fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> { fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
let path = path.as_ref(); let path = path.as_ref();
if !path.exists() { if !path.exists() {
return Err(anyhow::anyhow!("Model path does not exist: {}", path.display())); return Err(anyhow::anyhow!(
"Model path does not exist: {}",
path.display()
));
} }
if !path.is_dir() { if !path.is_dir() {
return Err(anyhow::anyhow!("Model path is not a directory: {}", path.display())); return Err(anyhow::anyhow!(
"Model path is not a directory: {}",
path.display()
));
} }
Ok(()) Ok(())
} }
\ No newline at end of file
...@@ -25,9 +25,9 @@ ...@@ -25,9 +25,9 @@
//! - Prompt formatter settings (PromptFormatterArtifact) //! - Prompt formatter settings (PromptFormatterArtifact)
//! - Various metadata like revision, publish time, etc. //! - Various metadata like revision, publish time, etc.
use crate::protocols::TokenIdType;
use anyhow::Result; use anyhow::Result;
use either::Either; use either::Either;
use crate::protocols::TokenIdType;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
...@@ -196,7 +196,6 @@ impl ModelDeploymentCard { ...@@ -196,7 +196,6 @@ impl ModelDeploymentCard {
} }
} }
impl fmt::Display for ModelDeploymentCard { impl fmt::Display for ModelDeploymentCard {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.slug()) write!(f, "{}", self.slug())
...@@ -256,7 +255,7 @@ struct HFConfigJsonFile { ...@@ -256,7 +255,7 @@ struct HFConfigJsonFile {
impl HFConfigJsonFile { impl HFConfigJsonFile {
async fn from_file(file: &File) -> Result<Arc<dyn ModelInfo>> { async fn from_file(file: &File) -> Result<Arc<dyn ModelInfo>> {
let contents = std::fs::read_to_string(&file)?; let contents = std::fs::read_to_string(file)?;
let config: Self = serde_json::from_str(&contents)?; let config: Self = serde_json::from_str(&contents)?;
Ok(Arc::new(config)) Ok(Arc::new(config))
} }
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use triton_llm::model_card::model::{ModelDeploymentCard, ModelInfoType, TokenizerKind, PromptFormatterArtifact};
use tempfile::tempdir; use tempfile::tempdir;
use triton_llm::model_card::model::{
ModelDeploymentCard, ModelInfoType, PromptFormatterArtifact, TokenizerKind,
};
#[tokio::test] #[tokio::test]
async fn test_model_info_from_hf_like_local_repo() { async fn test_model_info_from_hf_like_local_repo() {
...@@ -28,7 +30,6 @@ async fn test_model_info_from_hf_like_local_repo() { ...@@ -28,7 +30,6 @@ async fn test_model_info_from_hf_like_local_repo() {
assert_eq!(info.vocab_size(), 128256); assert_eq!(info.vocab_size(), 128256);
} }
#[tokio::test] #[tokio::test]
async fn test_model_info_from_non_existent_local_repo() { async fn test_model_info_from_non_existent_local_repo() {
let path = "tests/data/sample-models/this-model-does-not-exist"; let path = "tests/data/sample-models/this-model-does-not-exist";
...@@ -67,4 +68,4 @@ async fn test_missing_required_files() { ...@@ -67,4 +68,4 @@ async fn test_missing_required_files() {
let err = result.unwrap_err().to_string(); let err = result.unwrap_err().to_string();
// Should fail because config.json is missing // Should fail because config.json is missing
assert!(err.contains("unable to extract")); assert!(err.contains("unable to extract"));
} }
\ No newline at end of file
...@@ -137,6 +137,10 @@ impl Component { ...@@ -137,6 +137,10 @@ impl Component {
format!("{}.events.{}", self.service_name(), name.as_ref()) format!("{}.events.{}", self.service_name(), name.as_ref())
} }
/// Slash-separated identifier for this component: "{namespace}/{name}".
/// Human-readable form (used e.g. in endpoint paths and logs); distinct
/// from `etcd_path`, which is the key form stored in etcd.
pub fn path(&self) -> String {
    format!("{}/{}", self.namespace, self.name)
}
pub fn drt(&self) -> &DistributedRuntime { pub fn drt(&self) -> &DistributedRuntime {
&self.drt &self.drt
} }
...@@ -213,6 +217,10 @@ impl Endpoint { ...@@ -213,6 +217,10 @@ impl Endpoint {
&self.component &self.component
} }
/// Slash-separated identifier for this endpoint:
/// "{namespace}/{component}/{endpoint}". Human-readable form; see
/// `etcd_path` for the corresponding etcd key.
pub fn path(&self) -> String {
    format!("{}/{}", self.component.path(), self.name)
}
pub fn etcd_path(&self) -> String { pub fn etcd_path(&self) -> String {
format!("{}/{}", self.component.etcd_path(), self.name) format!("{}/{}", self.component.etcd_path(), self.name)
} }
......
...@@ -148,6 +148,11 @@ where ...@@ -148,6 +148,11 @@ where
}) })
} }
/// String identifying namespace/component/endpoint for this client's target.
pub fn path(&self) -> String {
    self.endpoint.path()
}
pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> { pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
&self.watch_rx &self.watch_rx
} }
......
...@@ -56,7 +56,7 @@ use tracing_subscriber::{filter::Directive, fmt}; ...@@ -56,7 +56,7 @@ use tracing_subscriber::{filter::Directive, fmt};
const FILTER_ENV: &str = "TRD_LOG"; const FILTER_ENV: &str = "TRD_LOG";
/// Default log level /// Default log level
const DEFAULT_FILTER_LEVEL: &str = "error"; const DEFAULT_FILTER_LEVEL: &str = "debug";
/// ENV used to set the path to the logging configuration file /// ENV used to set the path to the logging configuration file
const CONFIG_PATH_ENV: &str = "TRD_LOGGING_CONFIG_PATH"; const CONFIG_PATH_ENV: &str = "TRD_LOGGING_CONFIG_PATH";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment