Commit 7ab5df5d authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(tio): Defaults for in and out, support HF repos (#223)

You can now run an HF repo directly:
```
tio ~/llm_models/Llama-3.2-1B-Instruct/
```

or a GGUF
```
tio ~/llm_models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
```

Also clean up kv_router so I can merge.
parent b90535aa
......@@ -59,6 +59,7 @@ server/
# will have compiled files and executables
debug/
target/
llm_engine.h
### Virtual Environment ###
.venv/
......
......@@ -9,15 +9,6 @@ Rust:
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
Get the NATS server from https://nats.io/download/ and run it:
```
nats-server -js --trace --store_dir $(mktemp -d)
```
Get etcd from https://github.com/etcd-io/etcd/releases and run it: `etcd`
These components are required but not yet used by tio. It's a journey, OK.
## Build
- CUDA:
......@@ -40,11 +31,11 @@ For example one of these should be fast and good quality on almost any machine:
*Text interface*
`./target/release/tio in=text out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf`
`./target/release/tio Llama-3.2-1B-Instruct-Q4_K_M.gguf` — or pass the path to a Hugging Face repo checkout instead of the GGUF file.
*HTTP interface*
`./target/release/tio in=http out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf`
`./target/release/tio in=http --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf`
List the models: `curl localhost:8080/v1/models`
......
......@@ -27,6 +27,15 @@ pub use opt::{Input, Output};
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
pub struct Flags {
/// Full path to the model, which can be either a GGUF file or a checked out HF repository.
/// For the `echo_full` engine omit the flag.
#[arg(index = 1)]
pub model_path_pos: Option<PathBuf>,
// `--model-path`. The one above is `tio <positional-model-path>`
#[arg(long = "model-path")]
pub model_path_flag: Option<PathBuf>,
/// HTTP port. `in=http` only
#[arg(long, default_value = "8080")]
pub http_port: u16,
......@@ -34,12 +43,6 @@ pub struct Flags {
/// The name of the model we are serving
#[arg(long)]
pub model_name: Option<String>,
/// Full path to the model. This differs by engine:
/// - mistralrs: File. GGUF.
/// - echo_full: Omit the flag.
#[arg(long)]
pub model_path: Option<PathBuf>,
}
pub enum EngineConfig {
......@@ -57,7 +60,10 @@ pub async fn run(
cancel_token: CancellationToken,
) -> anyhow::Result<()> {
// Turn relative paths into absolute paths
let model_path = flags.model_path.and_then(|p| p.canonicalize().ok());
let model_path = flags
.model_path_pos
.or(flags.model_path_flag)
.and_then(|p| p.canonicalize().ok());
// Serve the model under the name provided, or the name of the GGUF file.
let model_name = flags.model_name.or_else(||
// "stem" means the filename without the extension.
......@@ -83,9 +89,6 @@ pub async fn run(
let Some(model_path) = model_path else {
anyhow::bail!("out=mistralrs requires flag --model-path=<full-path-to-model-gguf>");
};
if !model_path.is_file() {
anyhow::bail!("--model-path should refer to a GGUF file");
}
let Some(model_name) = model_name else {
unreachable!("We checked model_path earlier, and set model_name from model_path");
};
......
......@@ -17,17 +17,28 @@ use std::env;
use clap::Parser;
use tio::{Input, Output};
use triton_distributed::logging;
const HELP: &str = r#"
triton-llm service runner
tio is a single binary that wires together the various inputs (http, text, network) and workers (network, engine), that runs the services. It is the simplest way to use triton-distributed locally.
Example:
- cargo build --release --features mistralrs,cuda
- ./target/release/tio in=text out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf --model-name 'Llama-3.2-1B-Instruct'
- cd target/release
- ./tio hf_checkouts/Llama-3.2-3B-Instruct/
- OR: ./tio Llama-3.2-1B-Instruct-Q4_K_M.gguf
"#;
const DEFAULT_IN: Input = Input::Text;
#[cfg(feature = "mistralrs")]
const DEFAULT_OUT: Output = Output::MistralRs;
#[cfg(not(feature = "mistralrs"))]
const DEFAULT_OUT: Output = Output::EchoFull;
const USAGE: &str = "USAGE: tio in=[http|text] out=[mistralrs|echo_full] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>]";
fn main() -> anyhow::Result<()> {
......@@ -53,7 +64,8 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()>
}
for arg in env::args().skip(1).take(2) {
let Some((in_out, val)) = arg.split_once('=') else {
anyhow::bail!("Argument missing '='. {USAGE}");
// Probably we're defaulting in and/or out, and this is a flag
continue;
};
match in_out {
"in" => {
......@@ -67,13 +79,29 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()>
}
}
}
let (Some(in_opt), Some(out_opt)) = (in_opt, out_opt) else {
anyhow::bail!("Missing 'in' or 'out'. {USAGE}");
let mut non_flag_params = 1; // binary name
let in_opt = match in_opt {
Some(x) => {
non_flag_params += 1;
x
}
None => DEFAULT_IN,
};
let out_opt = match out_opt {
Some(x) => {
non_flag_params += 1;
x
}
None => DEFAULT_OUT,
};
// Clap skips the first argument expecting it to be the binary name, so add it back
let nio_flags =
tio::Flags::try_parse_from(["tio".to_string()].into_iter().chain(env::args().skip(3)))?;
// Note `--model-path` has index=1 (in lib.rs) so that doesn't need a flag.
let flags = tio::Flags::try_parse_from(
["tio".to_string()]
.into_iter()
.chain(env::args().skip(non_flag_params)),
)?;
// etcd and nats addresses, from env vars ETCD_ENDPOINTS and NATS_SERVER with localhost
// defaults
......@@ -81,5 +109,5 @@ async fn tio_wrapper(runtime: triton_distributed::Runtime) -> anyhow::Result<()>
// Wraps the Runtime (which wraps two tokio runtimes) and adds etcd and nats clients
//let d_runtime = triton_distributed::DistributedRuntime::new(runtime, dt_config).await?;
tio::run(in_opt, out_opt, nio_flags, runtime.primary_token()).await
tio::run(in_opt, out_opt, flags, runtime.primary_token()).await
}
......@@ -22,8 +22,8 @@ use indexmap::IndexMap;
use mistralrs::{
Constraint, DefaultSchedulerMethod, Device, DeviceMapMetadata, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalRequest, PagedAttentionConfig, Pipeline, Request, RequestMessage, ResponseOk,
SamplingParams, SchedulerConfig, TokenSource,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, TokenSource,
};
use tokio::sync::mpsc::channel;
......@@ -68,28 +68,45 @@ struct MistralRsEngine {
impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
let Some(model_filename) = model_path.file_name() else {
pipeline_error::bail!("Missing filename in model path");
};
let Some(model_dir) = model_path.parent() else {
pipeline_error::bail!("Invalid model path");
};
let loader = if model_path.is_file() {
// Load from a GGUF
let Some(model_filename) = model_path.file_name() else {
pipeline_error::bail!("Missing filename in model path");
};
let Some(model_dir) = model_path.parent() else {
pipeline_error::bail!("Invalid model path");
};
// Select a Mistral model
// We do not use any files from HF servers here, and instead load the
// chat template from the specified file, and the tokenizer and model from a
// local GGUF file at the path `.`
let loader = GGUFLoaderBuilder::new(
None,
None,
model_dir.display().to_string(),
vec![model_filename.to_string_lossy().into_owned()],
GGUFSpecificConfig {
prompt_chunksize: None,
topology: None,
},
)
.build();
GGUFLoaderBuilder::new(
None,
None,
model_dir.display().to_string(),
vec![model_filename.to_string_lossy().into_owned()],
GGUFSpecificConfig {
prompt_chunksize: None,
topology: None,
},
)
.build()
} else {
// Load from a HF repo dir
NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn: false,
prompt_chunksize: None,
topology: None,
organization: Default::default(),
write_uqff: None,
from_uqff: None,
imatrix: None,
calibration_file: None,
},
None,
None,
Some(model_path.display().to_string()),
)
.build(None)?
};
// Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") {
......
......@@ -45,6 +45,7 @@ pub struct KvRouter {
cancellation_token: CancellationToken,
#[allow(dead_code)]
scheduler: KvScheduler,
indexer: KvIndexer,
......
......@@ -16,7 +16,6 @@
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;
use tracing as log;
use uuid::Uuid;
......@@ -29,17 +28,11 @@ pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")]
NoEndpoints,
#[error("endpoints existed, but no valid routes were found")]
NoRoutes,
#[error("all workers busy")]
AllWorkersBusy,
#[error("endpoint subscriber shutdown")]
SubscriberShutdown,
#[error("scheduler offline")]
SchedulerOffline,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -81,7 +74,7 @@ pub struct SchedulingRequest {
impl SchedulingRequest {
pub fn respond(self, worker_id: String) {
if self.resp_tx.send(worker_id).is_err() {
log::trace!("failed to send response to requestor");
tracing::trace!("failed to send response to requestor");
}
}
}
......@@ -96,7 +89,7 @@ impl KvScheduler {
) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx;
log::trace!("awaiting the start of the background endpoint subscriber");
tracing::trace!("awaiting the start of the background endpoint subscriber");
let mut endpoints = match endpoints_rx.recv().await {
Some(endpoints) => endpoints,
None => {
......@@ -106,12 +99,12 @@ impl KvScheduler {
// Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
log::debug!("scheduler starting");
tracing::debug!("scheduler starting");
// Background task to handle scheduling requests
tokio::spawn(async move {
let mut request: SchedulingRequest;
let mut request_rx = request_rx;
log::debug!("scheduler background task started");
tracing::debug!("scheduler background task started");
'outer: loop {
request = tokio::select! {
......@@ -120,11 +113,11 @@ impl KvScheduler {
new_request = request_rx.recv() => {
match new_request {
Some(new_request) => {
log::trace!("received request to be scheduled");
tracing::trace!("received request to be scheduled");
new_request
},
None => {
log::trace!("scheduler shutdown");
tracing::trace!("scheduler shutdown");
break 'outer;
}
}
......@@ -133,18 +126,18 @@ impl KvScheduler {
new_endpoints = endpoints_rx.recv() => {
match new_endpoints {
Some(new_endpoints) => {
log::trace!("updated endpoints");
tracing::trace!("updated endpoints");
endpoints = new_endpoints;
continue 'outer;
}
None => {
log::trace!("endpoint subscriber shutdown");
tracing::trace!("endpoint subscriber shutdown");
break 'outer;
}
}
}
};
log::debug!("selected");
tracing::debug!("selected");
loop {
match select_worker(endpoints.borrow_mut(), &request) {
Ok(worker_id) => {
......@@ -152,29 +145,30 @@ impl KvScheduler {
continue 'outer;
}
Err(KvSchedulerError::AllWorkersBusy) => {
log::trace!("all workers busy; waiting for more capacity");
tracing::trace!("all workers busy; waiting for more capacity");
endpoints = match endpoints_rx.recv().await {
Some(endpoints) => endpoints,
None => {
log::trace!("endpoint subscriber shutdown");
tracing::trace!("endpoint subscriber shutdown");
break 'outer;
}
};
}
Err(e) => {
log::error!("error scheduling request: {:?}", e);
tracing::error!("error scheduling request: {:?}", e);
break 'outer;
}
}
}
}
log::trace!("background endpoint subscriber shutting down");
tracing::trace!("background endpoint subscriber shutting down");
});
Ok(KvScheduler { request_tx })
}
#[allow(dead_code)]
pub async fn schedule(
&self,
overlap: OverlapScores,
......@@ -186,17 +180,17 @@ impl KvScheduler {
overlap,
resp_tx,
};
log::debug!("before sending request");
tracing::debug!("before sending request");
self.request_tx
.send(request)
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
log::debug!("after sending request");
tracing::debug!("after sending request");
let res = resp_rx
.await
.map_err(|_| KvSchedulerError::SubscriberShutdown)?;
log::debug!("after receiving response");
tracing::debug!("after receiving response");
Ok(res)
}
}
......@@ -247,7 +241,7 @@ pub fn select_worker(
+ (1.0 - alpha) * normalized_new_tokens
+ gamma * request_load_ratio;
log::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}",
tracing::debug!("worker: {}; load_deviation: {}; normalized new blocks: {}; request_load_ratio: {} cost: {}",
worker_id,
load_deviation,
normalized_new_tokens,
......@@ -270,7 +264,7 @@ pub fn select_worker(
match best_index {
Some(i) => {
log::info!(
tracing::info!(
"selected worker: {}; cost: {}",
workers.endpoints[i].subject,
best_cost
......@@ -278,7 +272,7 @@ pub fn select_worker(
Ok(workers.endpoints[i].subject.clone())
}
None => {
log::debug!("all workers busy");
tracing::debug!("all workers busy");
Err(KvSchedulerError::AllWorkersBusy)
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment