Commit 7ab5df5d authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(tio): Defaults for in and out, support HF repos (#223)

You can now run an HF repo directly:
```
tio ~/llm_models/Llama-3.2-1B-Instruct/
```

or a GGUF
```
tio ~/llm_models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
```

Also clean up kv_router so I can merge.
parent b90535aa
...@@ -59,6 +59,7 @@ server/ ...@@ -59,6 +59,7 @@ server/
# will have compiled files and executables # will have compiled files and executables
debug/ debug/
target/ target/
llm_engine.h
### Virtual Environment ### ### Virtual Environment ###
.venv/ .venv/
......
...@@ -9,15 +9,6 @@ Rust: ...@@ -9,15 +9,6 @@ Rust:
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
``` ```
Get the NATS server from https://nats.io/download/ and run it:
```
nats-server -js --trace --store_dir $(mktemp -d)
```
Get etcd from https://github.com/etcd-io/etcd/releases and run it: `etcd`
These components are required but not yet used by tio. It's a journey, OK.
## Build ## Build
- CUDA: - CUDA:
...@@ -40,11 +31,11 @@ For example one of these should be fast and good quality on almost any machine: ...@@ -40,11 +31,11 @@ For example one of these should be fast and good quality on almost any machine:
*Text interface* *Text interface*
`./target/release/tio in=text out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf` `./target/release/tio Llama-3.2-1B-Instruct-Q4_K_M.gguf` or path to a Hugging Face repo checkout instead of the GGUF.
*HTTP interface* *HTTP interface*
`./target/release/tio in=http out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf` `./target/release/tio in=http --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf`
List the models: `curl localhost:8080/v1/models` List the models: `curl localhost:8080/v1/models`
......
...@@ -27,6 +27,15 @@ pub use opt::{Input, Output}; ...@@ -27,6 +27,15 @@ pub use opt::{Input, Output};
#[derive(clap::Parser, Debug, Clone)] #[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
pub struct Flags { pub struct Flags {
/// Full path to the model, which can be either a GGUF file or a checked out HF repository.
/// For the `echo_full` engine omit the flag.
#[arg(index = 1)]
pub model_path_pos: Option<PathBuf>,
// `--model-path`. The one above is `tio <positional-model-path>`
#[arg(long = "model-path")]
pub model_path_flag: Option<PathBuf>,
/// HTTP port. `in=http` only /// HTTP port. `in=http` only
#[arg(long, default_value = "8080")] #[arg(long, default_value = "8080")]
pub http_port: u16, pub http_port: u16,
...@@ -34,12 +43,6 @@ pub struct Flags { ...@@ -34,12 +43,6 @@ pub struct Flags {
/// The name of the model we are serving /// The name of the model we are serving
#[arg(long)] #[arg(long)]
pub model_name: Option<String>, pub model_name: Option<String>,
/// Full path to the model. This differs by engine:
/// - mistralrs: File. GGUF.
/// - echo_full: Omit the flag.
#[arg(long)]
pub model_path: Option<PathBuf>,
} }
pub enum EngineConfig { pub enum EngineConfig {
...@@ -57,7 +60,10 @@ pub async fn run( ...@@ -57,7 +60,10 @@ pub async fn run(
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Turn relative paths into absolute paths // Turn relative paths into absolute paths
let model_path = flags.model_path.and_then(|p| p.canonicalize().ok()); let model_path = flags
.model_path_pos
.or(flags.model_path_flag)
.and_then(|p| p.canonicalize().ok());
// Serve the model under the name provided, or the name of the GGUF file. // Serve the model under the name provided, or the name of the GGUF file.
let model_name = flags.model_name.or_else(|| let model_name = flags.model_name.or_else(||
// "stem" means the filename without the extension. // "stem" means the filename without the extension.
...@@ -83,9 +89,6 @@ pub async fn run( ...@@ -83,9 +89,6 @@ pub async fn run(
let Some(model_path) = model_path else { let Some(model_path) = model_path else {
anyhow::bail!("out=mistralrs requires flag --model-path=<full-path-to-model-gguf>"); anyhow::bail!("out=mistralrs requires flag --model-path=<full-path-to-model-gguf>");
}; };
if !model_path.is_file() {
anyhow::bail!("--model-path should refer to a GGUF file");
}
let Some(model_name) = model_name else { let Some(model_name) = model_name else {
unreachable!("We checked model_path earlier, and set model_name from model_path"); unreachable!("We checked model_path earlier, and set model_name from model_path");
}; };
......
...@@ -17,17 +17,28 @@ use std::env; ...@@ -17,17 +17,28 @@ use std::env;
use clap::Parser; use clap::Parser;
use tio::{Input, Output};
use triton_distributed::logging; use triton_distributed::logging;
const HELP: &str = r#" const HELP: &str = r#"
triton-llm service runner tio is a single binary that wires together the various inputs (http, text, network) and workers (network, engine), that runs the services. It is the simplest way to use triton-distributed locally.
Example: Example:
- cargo build --release --features mistralrs,cuda - cargo build --release --features mistralrs,cuda
- ./target/release/tio in=text out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf --model-name 'Llama-3.2-1B-Instruct' - cd target/release
- ./tio hf_checkouts/Llama-3.2-3B-Instruct/
- OR: ./tio Llama-3.2-1B-Instruct-Q4_K_M.gguf
"#; "#;
const DEFAULT_IN: Input = Input::Text;
#[cfg(feature = "mistralrs")]
const DEFAULT_OUT: Output = Output::MistralRs;
#[cfg(not(feature = "mistralrs"))]
const DEFAULT_OUT: Output = Output::EchoFull;
const USAGE: &str = "USAGE: tio in=[http|text] out=[mistralrs|echo_full] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>]"; const USAGE: &str = "USAGE: tio in=[http|text] out=[mistralrs|echo_full] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>]";
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
...@@ -53,7 +64,8 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()> ...@@ -53,7 +64,8 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()>
} }
for arg in env::args().skip(1).take(2) { for arg in env::args().skip(1).take(2) {
let Some((in_out, val)) = arg.split_once('=') else { let Some((in_out, val)) = arg.split_once('=') else {
anyhow::bail!("Argument missing '='. {USAGE}"); // Probably we're defaulting in and/or out, and this is a flag
continue;
}; };
match in_out { match in_out {
"in" => { "in" => {
...@@ -67,13 +79,29 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()> ...@@ -67,13 +79,29 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()>
} }
} }
} }
let (Some(in_opt), Some(out_opt)) = (in_opt, out_opt) else { let mut non_flag_params = 1; // binary name
anyhow::bail!("Missing 'in' or 'out'. {USAGE}"); let in_opt = match in_opt {
Some(x) => {
non_flag_params += 1;
x
}
None => DEFAULT_IN,
};
let out_opt = match out_opt {
Some(x) => {
non_flag_params += 1;
x
}
None => DEFAULT_OUT,
}; };
// Clap skips the first argument expecting it to be the binary name, so add it back // Clap skips the first argument expecting it to be the binary name, so add it back
let nio_flags = // Note `--model-path` has index=1 (in lib.rs) so that doesn't need a flag.
tio::Flags::try_parse_from(["tio".to_string()].into_iter().chain(env::args().skip(3)))?; let flags = tio::Flags::try_parse_from(
["tio".to_string()]
.into_iter()
.chain(env::args().skip(non_flag_params)),
)?;
// etcd and nats addresses, from env vars ETCD_ENDPOINTS and NATS_SERVER with localhost // etcd and nats addresses, from env vars ETCD_ENDPOINTS and NATS_SERVER with localhost
// defaults // defaults
...@@ -81,5 +109,5 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()> ...@@ -81,5 +109,5 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()>
// Wraps the Runtime (which wraps two tokio runtimes) and adds etcd and nats clients // Wraps the Runtime (which wraps two tokio runtimes) and adds etcd and nats clients
//let d_runtime = triton_distributed::DistributedRuntime::new(runtime, dt_config).await?; //let d_runtime = triton_distributed::DistributedRuntime::new(runtime, dt_config).await?;
tio::run(in_opt, out_opt, nio_flags, runtime.primary_token()).await tio::run(in_opt, out_opt, flags, runtime.primary_token()).await
} }
...@@ -22,8 +22,8 @@ use indexmap::IndexMap; ...@@ -22,8 +22,8 @@ use indexmap::IndexMap;
use mistralrs::{ use mistralrs::{
Constraint, DefaultSchedulerMethod, Device, DeviceMapMetadata, DeviceMapSetting, Constraint, DefaultSchedulerMethod, Device, DeviceMapMetadata, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder, GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalRequest, PagedAttentionConfig, Pipeline, Request, RequestMessage, ResponseOk, ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
SamplingParams, SchedulerConfig, TokenSource, Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, TokenSource,
}; };
use tokio::sync::mpsc::channel; use tokio::sync::mpsc::channel;
...@@ -68,28 +68,45 @@ struct MistralRsEngine { ...@@ -68,28 +68,45 @@ struct MistralRsEngine {
impl MistralRsEngine { impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> { async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
let Some(model_filename) = model_path.file_name() else { let loader = if model_path.is_file() {
pipeline_error::bail!("Missing filename in model path"); // Load from a GGUF
}; let Some(model_filename) = model_path.file_name() else {
let Some(model_dir) = model_path.parent() else { pipeline_error::bail!("Missing filename in model path");
pipeline_error::bail!("Invalid model path"); };
}; let Some(model_dir) = model_path.parent() else {
pipeline_error::bail!("Invalid model path");
};
// Select a Mistral model GGUFLoaderBuilder::new(
// We do not use any files from HF servers here, and instead load the None,
// chat template from the specified file, and the tokenizer and model from a None,
// local GGUF file at the path `.` model_dir.display().to_string(),
let loader = GGUFLoaderBuilder::new( vec![model_filename.to_string_lossy().into_owned()],
None, GGUFSpecificConfig {
None, prompt_chunksize: None,
model_dir.display().to_string(), topology: None,
vec![model_filename.to_string_lossy().into_owned()], },
GGUFSpecificConfig { )
prompt_chunksize: None, .build()
topology: None, } else {
}, // Load from a HF repo dir
) NormalLoaderBuilder::new(
.build(); NormalSpecificConfig {
use_flash_attn: false,
prompt_chunksize: None,
topology: None,
organization: Default::default(),
write_uqff: None,
from_uqff: None,
imatrix: None,
calibration_file: None,
},
None,
None,
Some(model_path.display().to_string()),
)
.build(None)?
};
// Paged attention requires cuda // Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") { let paged_attention_config = if cfg!(feature = "cuda") {
......
...@@ -45,6 +45,7 @@ pub struct KvRouter { ...@@ -45,6 +45,7 @@ pub struct KvRouter {
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
#[allow(dead_code)]
scheduler: KvScheduler, scheduler: KvScheduler,
indexer: KvIndexer, indexer: KvIndexer,
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use std::cmp::min; use std::cmp::min;
use tracing as log;
use uuid::Uuid; use uuid::Uuid;
...@@ -29,17 +28,11 @@ pub enum KvSchedulerError { ...@@ -29,17 +28,11 @@ pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")] #[error("no endpoints aviailable to route work")]
NoEndpoints, NoEndpoints,
#[error("endpoints existed, but no valid routes were found")]
NoRoutes,
#[error("all workers busy")] #[error("all workers busy")]
AllWorkersBusy, AllWorkersBusy,
#[error("endpoint subscriber shutdown")] #[error("endpoint subscriber shutdown")]
SubscriberShutdown, SubscriberShutdown,
#[error("scheduler offline")]
SchedulerOffline,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
...@@ -81,7 +74,7 @@ pub struct SchedulingRequest { ...@@ -81,7 +74,7 @@ pub struct SchedulingRequest {
impl SchedulingRequest { impl SchedulingRequest {
pub fn respond(self, worker_id: String) { pub fn respond(self, worker_id: String) {
if self.resp_tx.send(worker_id).is_err() { if self.resp_tx.send(worker_id).is_err() {
log::trace!("failed to send response to requestor"); tracing::trace!("failed to send response to requestor");
} }
} }
} }
...@@ -96,7 +89,7 @@ impl KvScheduler { ...@@ -96,7 +89,7 @@ impl KvScheduler {
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx; let mut endpoints_rx = endpoints_rx;
log::trace!("awaiting the start of the background endpoint subscriber"); tracing::trace!("awaiting the start of the background endpoint subscriber");
let mut endpoints = match endpoints_rx.recv().await { let mut endpoints = match endpoints_rx.recv().await {
Some(endpoints) => endpoints, Some(endpoints) => endpoints,
None => { None => {
...@@ -106,12 +99,12 @@ impl KvScheduler { ...@@ -106,12 +99,12 @@ impl KvScheduler {
// Channel to accept new scheduling requests // Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
log::debug!("scheduler starting"); tracing::debug!("scheduler starting");
// Background task to handle scheduling requests // Background task to handle scheduling requests
tokio::spawn(async move { tokio::spawn(async move {
let mut request: SchedulingRequest; let mut request: SchedulingRequest;
let mut request_rx = request_rx; let mut request_rx = request_rx;
log::debug!("scheduler background task started"); tracing::debug!("scheduler background task started");
'outer: loop { 'outer: loop {
request = tokio::select! { request = tokio::select! {
...@@ -120,11 +113,11 @@ impl KvScheduler { ...@@ -120,11 +113,11 @@ impl KvScheduler {
new_request = request_rx.recv() => { new_request = request_rx.recv() => {
match new_request { match new_request {
Some(new_request) => { Some(new_request) => {
log::trace!("received request to be scheduled"); tracing::trace!("received request to be scheduled");
new_request new_request
}, },
None => { None => {
log::trace!("scheduler shutdown"); tracing::trace!("scheduler shutdown");
break 'outer; break 'outer;
} }
} }
...@@ -133,18 +126,18 @@ impl KvScheduler { ...@@ -133,18 +126,18 @@ impl KvScheduler {
new_endpoints = endpoints_rx.recv() => { new_endpoints = endpoints_rx.recv() => {
match new_endpoints { match new_endpoints {
Some(new_endpoints) => { Some(new_endpoints) => {
log::trace!("updated endpoints"); tracing::trace!("updated endpoints");
endpoints = new_endpoints; endpoints = new_endpoints;
continue 'outer; continue 'outer;
} }
None => { None => {
log::trace!("endpoint subscriber shutdown"); tracing::trace!("endpoint subscriber shutdown");
break 'outer; break 'outer;
} }
} }
} }
}; };
log::debug!("selected"); tracing::debug!("selected");
loop { loop {
match select_worker(endpoints.borrow_mut(), &request) { match select_worker(endpoints.borrow_mut(), &request) {
Ok(worker_id) => { Ok(worker_id) => {
...@@ -152,29 +145,30 @@ impl KvScheduler { ...@@ -152,29 +145,30 @@ impl KvScheduler {
continue 'outer; continue 'outer;
} }
Err(KvSchedulerError::AllWorkersBusy) => { Err(KvSchedulerError::AllWorkersBusy) => {
log::trace!("all workers busy; waiting for more capacity"); tracing::trace!("all workers busy; waiting for more capacity");
endpoints = match endpoints_rx.recv().await { endpoints = match endpoints_rx.recv().await {
Some(endpoints) => endpoints, Some(endpoints) => endpoints,
None => { None => {
log::trace!("endpoint subscriber shutdown"); tracing::trace!("endpoint subscriber shutdown");
break 'outer; break 'outer;
} }
}; };
} }
Err(e) => { Err(e) => {
log::error!("error scheduling request: {:?}", e); tracing::error!("error scheduling request: {:?}", e);
break 'outer; break 'outer;
} }
} }
} }
} }
log::trace!("background endpoint subscriber shutting down"); tracing::trace!("background endpoint subscriber shutting down");
}); });
Ok(KvScheduler { request_tx }) Ok(KvScheduler { request_tx })
} }
#[allow(dead_code)]
pub async fn schedule( pub async fn schedule(
&self, &self,
overlap: OverlapScores, overlap: OverlapScores,
...@@ -186,17 +180,17 @@ impl KvScheduler { ...@@ -186,17 +180,17 @@ impl KvScheduler {
overlap, overlap,
resp_tx, resp_tx,
}; };
log::debug!("before sending request"); tracing::debug!("before sending request");
self.request_tx self.request_tx
.send(request) .send(request)
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
log::debug!("after sending request"); tracing::debug!("after sending request");
let res = resp_rx let res = resp_rx
.await .await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?; .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
log::debug!("after receiving response"); tracing::debug!("after receiving response");
Ok(res) Ok(res)
} }
} }
...@@ -247,7 +241,7 @@ pub fn select_worker( ...@@ -247,7 +241,7 @@ pub fn select_worker(
+ (1.0 - alpha) * normalized_new_tokens + (1.0 - alpha) * normalized_new_tokens
+ gamma * request_load_ratio; + gamma * request_load_ratio;
log::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}", tracing::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}",
worker_id, worker_id,
load_deviation, load_deviation,
normalized_new_tokens, normalized_new_tokens,
...@@ -270,7 +264,7 @@ pub fn select_worker( ...@@ -270,7 +264,7 @@ pub fn select_worker(
match best_index { match best_index {
Some(i) => { Some(i) => {
log::info!( tracing::info!(
"selected worker: {}; cost: {}", "selected worker: {}; cost: {}",
workers.endpoints[i].subject, workers.endpoints[i].subject,
best_cost best_cost
...@@ -278,7 +272,7 @@ pub fn select_worker( ...@@ -278,7 +272,7 @@ pub fn select_worker(
Ok(workers.endpoints[i].subject.clone()) Ok(workers.endpoints[i].subject.clone())
} }
None => { None => {
log::debug!("all workers busy"); tracing::debug!("all workers busy");
Err(KvSchedulerError::AllWorkersBusy) Err(KvSchedulerError::AllWorkersBusy)
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment