Commit 089f8e1b authored by Graham King, committed by GitHub

feat(dynamo-run): Download models from HF, smart model defaults (#126)



- Any engine can take the name of a Hugging Face repository; the model is downloaded before the engine is started.

- The default engine (previously always mistralrs) depends on what is compiled in.

- Text can be piped in; it is used as a single prompt, and the program exits after the response.

Taken together, these mean that if you build with `--features vllm` you can do the following; it will download the model, run it with vllm, answer your question, and exit:
```
echo "What is the capital of Costa Rica?"  | dynamo-run Qwen/Qwen2.5-3B-Instruct
```
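Without piping anything in, the same command starts an interactive chat (the default `in=text` when stdin is a terminal). A sketch of both forms; combining the positional model argument with an explicit `out=` engine is assumed from the usage string rather than shown verbatim in this change:
```
# Interactive chat, engine chosen from the compiled-in features
dynamo-run Qwen/Qwen2.5-3B-Instruct

# Explicitly select an engine (assumes the build enabled --features vllm)
dynamo-run out=vllm Qwen/Qwen2.5-3B-Instruct
```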
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 1d856345
......@@ -307,6 +307,31 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "aws-lc-rs"
version = "1.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e4e8200b9a4a5801a769d50eeabc05670fec7e959a8cb7a63a93e4e519942ae"
dependencies = [
"aws-lc-sys",
"paste",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f9dd2e03ee80ca2822dd6ea431163d2ef259f2066a4d6ccaca6d9dcb386aa43"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"dunce",
"fs_extra",
"paste",
]
[[package]]
name = "axum"
version = "0.7.9"
......@@ -674,6 +699,7 @@ dependencies = [
"num_cpus",
"rand",
"reqwest",
"rustls",
"serde",
"serde_json",
"thiserror 1.0.69",
......@@ -1390,6 +1416,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]]
name = "dunce"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
[[package]]
name = "dyn-clone"
version = "1.0.19"
......@@ -1468,6 +1500,7 @@ dependencies = [
"async-openai",
"async-stream",
"async-trait",
"candle-hf-hub",
"clap",
"dialoguer",
"dynamo-llm",
......@@ -1887,6 +1920,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fuchsia-zircon"
version = "0.3.3"
......@@ -2371,6 +2410,7 @@ dependencies = [
"tokio",
"tokio-rustls",
"tower-service",
"webpki-roots",
]
[[package]]
......@@ -4537,6 +4577,7 @@ dependencies = [
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"webpki-roots",
"windows-registry",
]
......@@ -4647,6 +4688,7 @@ version = "0.23.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
dependencies = [
"aws-lc-rs",
"log",
"once_cell",
"ring",
......@@ -4705,6 +4747,7 @@ version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"aws-lc-rs",
"ring",
"rustls-pki-types",
"untrusted",
......
......@@ -36,6 +36,7 @@ anyhow = "1"
async-openai = "0.27.2"
async-stream = { version = "0.3" }
async-trait = { version = "0.1" }
candle-hf-hub = { version = "0.3.3", default-features = false, features = ["tokio", "rustls-tls"] }
clap = { version = "4.5", features = ["derive", "env"] }
dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] }
futures = { version = "0.3" }
......
......@@ -38,9 +38,11 @@ If you have an `HF_TOKEN` environment variable set, this will download Qwen2.5 3
dynamo-run Qwen/Qwen2.5-3B-Instruct
```
The parameter can be the ID of a HuggingFace repository (it will be downloaded), a GGUF file, or a folder containing safetensors, config.json, etc (a locally checked out HuggingFace repository).
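For example (the local paths below are hypothetical, purely to illustrate the three accepted forms):
```
# HuggingFace repo ID - downloaded automatically before the engine starts
dynamo-run Qwen/Qwen2.5-3B-Instruct

# A single GGUF file already on disk (hypothetical path)
dynamo-run /models/Llama-3.2-3B-Instruct-Q4_K_M.gguf

# A locally checked-out HuggingFace repo (folder with safetensors, config.json, ...)
dynamo-run ./Qwen2.5-3B-Instruct/
```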
## Download a model from Hugging Face
For example one of these should be fast and good quality on almost any machine: https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF
One of these should be fast and good quality on almost any machine: https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF
## Run
......@@ -349,3 +351,10 @@ DYN_TOKEN_ECHO_DELAY_MS=1 dynamo-run in=http out=echo_full
```
The default delay is 10ms, which produces approximately 100 tokens per second.
## Defaults
The input defaults to `in=text` when stdin is a terminal, and to `in=stdin` when something is piped in.
The output defaults to whatever engine you have compiled in (so it depends on `--features`). If all features
are enabled at build time, then the default is currently `out=vllm`.
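Both defaults can be overridden on the command line. For example (assuming mistralrs was compiled in):
```
# Serve an OpenAI-compatible HTTP endpoint with an explicit engine
dynamo-run in=http out=mistralrs --model-path Qwen/Qwen2.5-3B-Instruct
```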
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use candle_hf_hub::api::tokio::ApiBuilder;
use std::path::{Path, PathBuf};
const IGNORED: [&str; 3] = [".gitattributes", "LICENSE", "README.md"];
/// Attempt to download a model from Hugging Face
/// Returns the directory it is in
pub async fn from_hf(name: &Path) -> anyhow::Result<PathBuf> {
let api = ApiBuilder::new().with_progress(true).build()?;
let repo = api.model(name.display().to_string());
let info = repo.info().await?;
let mut p = PathBuf::new();
for sib in info.siblings {
if IGNORED.contains(&sib.rfilename.as_str()) || is_image(&sib.rfilename) {
continue;
}
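// Download into the local HF cache; the parent directory of the last file fetched is returned as the model folder.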
p = repo.get(&sib.rfilename).await?;
}
match p.parent() {
Some(p) => Ok(p.to_path_buf()),
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
}
}
fn is_image(s: &str) -> bool {
s.ends_with(".png") || s.ends_with("PNG") || s.ends_with(".jpg") || s.ends_with("JPG")
}
......@@ -31,7 +31,7 @@ use dynamo_runtime::{
};
use futures::StreamExt;
use std::{
io::{ErrorKind, Read, Write},
io::{ErrorKind, Write},
sync::Arc,
};
......@@ -40,12 +40,10 @@ use crate::EngineConfig;
/// Max response tokens for each single query. Must be less than model context size.
const MAX_TOKENS: u32 = 8192;
/// Output of `isatty` if the fd is indeed a TTY
const IS_A_TTY: i32 = 1;
pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken,
single_prompt: Option<String>,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
let (service_name, engine, inspect_template): (
......@@ -75,7 +73,7 @@ pub async fn run(
service_name,
engine,
} => {
tracing::info!("Model: {service_name}");
tracing::debug!("Model: {service_name}");
(service_name, engine, false)
}
EngineConfig::StaticCore {
......@@ -101,12 +99,19 @@ pub async fn run(
.link(preprocessor.backward_edge())?
.link(frontend)?;
tracing::info!("Model: {service_name} with pre-processing");
tracing::debug!("Model: {service_name} with pre-processing");
(service_name, pipeline, true)
}
EngineConfig::None => unreachable!(),
};
main_loop(cancel_token, &service_name, engine, inspect_template).await
main_loop(
cancel_token,
&service_name,
engine,
single_prompt,
inspect_template,
)
.await
}
#[allow(deprecated)]
......@@ -114,20 +119,17 @@ async fn main_loop(
cancel_token: CancellationToken,
service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine,
mut initial_prompt: Option<String>,
_inspect_template: bool,
) -> anyhow::Result<()> {
if initial_prompt.is_none() {
tracing::info!("Ctrl-c to exit");
}
let theme = dialoguer::theme::ColorfulTheme::default();
let mut initial_prompt = if unsafe { libc::isatty(libc::STDIN_FILENO) == IS_A_TTY } {
None
} else {
// Something piped in, use that as initial prompt
let mut input = String::new();
std::io::stdin().read_to_string(&mut input).unwrap();
Some(input)
};
// Initial prompt is the pipe case: `echo "Hello" | dynamo-run ..`
// We run that single prompt and exit
let single = initial_prompt.is_some();
let mut history = dialoguer::BasicHistory::default();
let mut messages = vec![];
while !cancel_token.is_cancelled() {
......@@ -236,6 +238,10 @@ async fn main_loop(
},
);
messages.push(assistant_message);
if single {
break;
}
}
println!();
Ok(())
......
......@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Read;
#[cfg(any(feature = "vllm", feature = "sglang"))]
use std::{future::Future, pin::Pin};
......@@ -24,6 +25,7 @@ use dynamo_runtime::protocols::Endpoint;
mod flags;
pub use flags::Flags;
mod hub;
mod input;
#[cfg(any(feature = "vllm", feature = "sglang"))]
mod net;
......@@ -80,7 +82,7 @@ pub async fn run(
let cancel_token = runtime.primary_token();
// Turn relative paths into absolute paths
let model_path = flags
let mut model_path = flags
.model_path_pos
.clone()
.or(flags.model_path_flag.clone())
......@@ -91,8 +93,9 @@ pub async fn run(
Some(p)
}
});
// Serve the model under the name provided, or the name of the GGUF file or HF repo.
let model_name = flags
let mut model_name = flags
.model_name
.clone()
.or_else(|| {
......@@ -108,6 +111,18 @@ pub async fn run(
None
}
});
// If the path does not exist locally, treat it as a Hugging Face repo ID and download it
if let Some(inner_model_path) = model_path.as_ref() {
if !inner_model_path.exists() {
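// The repo's final path segment (e.g. "Qwen2.5-3B-Instruct") becomes the served model name.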
model_name = inner_model_path
.iter()
.last()
.map(|s| s.to_string_lossy().to_string());
model_path = Some(hub::from_hf(inner_model_path).await?);
}
}
// Load the model deployment card, if any
// Only used by some engines, so without those feature flags it's unused.
#[allow(unused_variables)]
......@@ -373,7 +388,19 @@ pub async fn run(
crate::input::http::run(runtime.clone(), flags.http_port, engine_config).await?;
}
Input::Text => {
crate::input::text::run(runtime.clone(), cancel_token.clone(), engine_config).await?;
crate::input::text::run(runtime.clone(), cancel_token.clone(), None, engine_config)
.await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
crate::input::text::run(
runtime.clone(),
cancel_token.clone(),
Some(prompt),
engine_config,
)
.await?;
}
Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
......
......@@ -31,14 +31,6 @@ Example:
"#;
const DEFAULT_IN: Input = Input::Text;
#[cfg(feature = "mistralrs")]
const DEFAULT_OUT: Output = Output::MistralRs;
#[cfg(not(feature = "mistralrs"))]
const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "dyn";
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|trtllm|echo_full|echo_core|pystr:<engine.py>|pytok:<engine.py>] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0]";
......@@ -159,14 +151,18 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
non_flag_params += 1;
x
}
None => DEFAULT_IN,
None => Input::default(),
};
let out_opt = match out_opt {
Some(x) => {
non_flag_params += 1;
x
}
None => DEFAULT_OUT,
None => {
let default_engine = Output::default(); // smart default based on feature flags
tracing::debug!("Using engine: {default_engine}");
default_engine
}
};
// Clap skips the first argument expecting it to be the binary name, so add it back
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt;
use std::{fmt, io::IsTerminal as _};
use crate::ENDPOINT_SCHEME;
......@@ -22,7 +22,10 @@ pub enum Input {
/// Run an OpenAI compatible HTTP server
Http,
/// Read prompt from stdin
/// Single prompt on stdin
Stdin,
/// Interactive chat
Text,
/// Pull requests from a namespace/component/endpoint path.
......@@ -41,6 +44,7 @@ impl TryFrom<&str> for Input {
match s {
"http" => Ok(Input::Http),
"text" => Ok(Input::Text),
"stdin" => Ok(Input::Stdin),
"none" => Ok(Input::None),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
......@@ -56,6 +60,7 @@ impl fmt::Display for Input {
let s = match self {
Input::Http => "http",
Input::Text => "text",
Input::Stdin => "stdin",
Input::Endpoint(path) => path,
Input::None => "none",
};
......@@ -63,6 +68,16 @@ impl fmt::Display for Input {
}
}
impl Default for Input {
fn default() -> Self {
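// Interactive terminal -> chat mode; piped stdin -> read a single prompt and exit.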
if std::io::stdin().is_terminal() {
Input::Text
} else {
Input::Stdin
}
}
}
pub enum Output {
/// Accept un-preprocessed requests, echo the prompt back as the response
EchoFull,
......@@ -185,3 +200,36 @@ impl fmt::Display for Output {
write!(f, "{s}")
}
}
/// Returns the engine to use if the user did not specify one on the command line.
/// Uses whatever was compiled in, with a priority ordering.
#[allow(unused_assignments, unused_mut)]
impl Default for Output {
fn default() -> Self {
// Default if no engines
let mut out = Output::EchoFull;
// Runs everywhere but needs local CUDA to build
#[cfg(feature = "mistralrs")]
{
out = Output::MistralRs;
}
#[cfg(feature = "llamacpp")]
{
out = Output::LlamaCpp;
}
#[cfg(feature = "sglang")]
{
out = Output::SgLang;
}
#[cfg(feature = "vllm")]
{
out = Output::Vllm;
}
out
}
}
......@@ -336,7 +336,7 @@ async fn start_vllm(
let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await {
let mut line_parts = line.splitn(4, ' ');
let log_level = line_parts.next().unwrap_or_default();
let mut log_level = line_parts.next().unwrap_or_default();
// Skip date (0) and time (1). Print last (2) which is everything else.
let line = line_parts.nth(2).unwrap_or_default();
if line.starts_with("custom_op.py:68") {
......@@ -344,10 +344,14 @@ async fn start_vllm(
// custom_op.py:68] custom op <the op> enabled
continue;
}
if line.contains("ERROR") {
log_level = "ERROR";
}
match log_level {
"DEBUG" => tracing::debug!("VLLM: {line}"),
"INFO" => tracing::info!("VLLM: {line}"),
"INFO" => tracing::debug!("VLLM: {line}"), // VLLM is noisy
"WARNING" => tracing::warn!("VLLM: {line}"),
"ERROR" => tracing::error!("VLLM: {line}"),
level => tracing::info!("VLLM: {level} {line}"),
}
}
......
......@@ -153,7 +153,7 @@ impl Worker {
match &result {
Ok(_) => {
tracing::info!("Application shutdown successfully");
tracing::debug!("Application shutdown successfully");
}
Err(e) => {
tracing::error!("Application shutdown with error: {:?}", e);
......@@ -200,7 +200,7 @@ async fn signal_handler(cancel_token: CancellationToken) -> Result<()> {
tracing::info!("SIGTERM received, starting graceful shutdown");
},
_ = cancel_token.cancelled() => {
tracing::info!("CancellationToken triggered; shutting down");
tracing::debug!("CancellationToken triggered; shutting down");
},
}
......