Unverified Commit 92f06b0e authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(dynamo-run): Refactor to library (#1687)

Move much of what was in the `dynamo-run` crate into `dynamo-llm` so that everyone can use it.

Example usage:

1. Create a `LocalModel`:

```
    let local_model = LocalModelBuilder::default()
	.model_path("Qwen/Qwen3-0.6B")
	.http_port(8080)
	.build().await?;
```

2. Make an engine:

```
    let engine_config = EngineConfig::StaticFull {
	engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
	model: Box::new(local_model),
    };
```

3. Connect it to an input and run it:

```
    dynamo_llm::entrypoint::input::run_input(Input::Http, runtime, engine_config).await?;
```

For https://github.com/ai-dynamo/dynamo/issues/1647

Code Rabbit summary, thanks:
  * Introduced a flexible builder pattern for local model configuration, allowing advanced customization and easier initialization.
  * Added new input modes and unified input handling, supporting interactive chat, HTTP server, batch file, and distributed endpoint modes.
  * Centralized engine configuration and routing, enabling more extensible and maintainable engine management.
  * Simplified and modularized the codebase by moving input and engine logic into dedicated modules.
  * Replaced direct model construction with an asynchronous builder for improved clarity and extensibility.
  * Streamlined configuration and validation for flags and router settings.
  * Added validation to prevent incompatible input and output combinations in endpoint and dynamic modes.
parent 3b62692f
......@@ -1619,6 +1619,7 @@ dependencies = [
"cudarc 0.16.2",
"derive-getters",
"derive_builder",
"dialoguer",
"dynamo-runtime",
"either",
"erased-serde",
......@@ -1627,6 +1628,7 @@ dependencies = [
"galil-seiferas",
"ggus",
"hf-hub",
"humantime",
"insta",
"itertools 0.14.0",
"lazy_static",
......@@ -1677,14 +1679,12 @@ dependencies = [
"async-stream",
"async-trait",
"clap",
"dialoguer",
"dynamo-engine-llamacpp",
"dynamo-engine-mistralrs",
"dynamo-llm",
"dynamo-runtime",
"futures",
"futures-util",
"humantime",
"libc",
"regex",
"serde",
......
......@@ -47,7 +47,7 @@ struct Args {
/// Block size for the router
#[arg(long)]
block_size: usize,
block_size: u32,
}
fn main() -> Result<()> {
......@@ -88,7 +88,7 @@ impl WorkerSelector for CustomWorkerSelector {
&self,
workers: &ProcessedEndpoints,
request: &SchedulingRequest,
block_size: usize,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
// customize logic here
// F12 into [DefaultWorkerSelector] to see the original logic
......
......@@ -34,7 +34,6 @@ anyhow = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
humantime = { workspace = true }
libc = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
......@@ -47,7 +46,6 @@ uuid = { workspace = true }
async-openai = { workspace = true }
clap = { version = "4.5", features = ["derive", "env"] }
dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] }
futures-util = { version = "0.3" }
regex = "1"
......
......@@ -17,9 +17,13 @@ use std::collections::HashMap;
use std::path::PathBuf;
use clap::ValueEnum;
use dynamo_llm::entrypoint::RouterConfig;
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode;
use crate::Output;
/// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
......@@ -125,11 +129,11 @@ pub struct Flags {
/// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
#[arg(long)]
pub context_length: Option<usize>,
pub context_length: Option<u32>,
/// KV cache block size (vllm only)
#[arg(long)]
pub kv_cache_block_size: Option<usize>,
pub kv_cache_block_size: Option<u32>,
/// Additional engine-specific arguments from a JSON file.
/// Contains a mapping of parameter names to values.
......@@ -154,68 +158,65 @@ pub struct Flags {
}
impl Flags {
/// Get KV router configuration
pub fn kv_router_config(&self) -> KvRouterConfig {
/// For the chosen `Output` variant, check whether it would be able to run.
/// This takes validation out of the main engine creation path.
///
/// # Errors
///
/// Returns an error naming the first flag or model path that is incompatible
/// with `out_opt`.
pub fn validate(&self, local_model: &LocalModel, out_opt: &Output) -> anyhow::Result<()> {
    match out_opt {
        Output::Dynamic => {
            // These two flags are rejected here; the messages direct the user to set them
            // on the worker node instead.
            if self.context_length.is_some() {
                anyhow::bail!("'--context-length' flag should only be used on the worker node, not on the ingress");
            }
            if self.kv_cache_block_size.is_some() {
                anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
            }
        }
        Output::EchoFull => {}
        Output::EchoCore => {
            if !local_model.card().has_tokenizer() {
                anyhow::bail!(
                    "out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
                );
            }
        }
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => {}
        Output::SgLang => {
            if !local_model.path().is_dir() {
                // TODO GGUF support for sglang: https://github.com/ai-dynamo/dynamo/issues/572
                anyhow::bail!("`--model-path` should point at a HuggingFace repo checkout");
            }
        }
        Output::Vllm => {
            if self.base_gpu_id != 0 {
                anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
            }
        }
        Output::Trtllm => {
            if self.base_gpu_id != 0 {
                anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
            }
        }
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => {
            if !local_model.path().is_file() {
                anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
            }
        }
    }
    Ok(())
}
/// Assemble the [`RouterConfig`] from the router mode flag and the three KV weight flags.
pub fn router_config(&self) -> RouterConfig {
    let kv_router_config = KvRouterConfig::new(
        self.kv_overlap_score_weight,
        self.kv_gpu_cache_usage_weight,
        self.kv_waiting_requests_weight,
    );
    RouterConfig::new(self.router_mode.into(), kv_router_config)
}
/// Turn the flags back into command line arguments.
///
/// Flags with a default are always emitted; optional flags are emitted only when set.
/// `path` and `name` are the canonicalized model path and normalized model name.
/// Trailing pass-through arguments (`self.last`) are appended at the end.
///
/// Used to pass arguments to python engines via `pystr` and `pytok`.
pub fn as_vec(&self, path: &str, name: &str) -> Vec<String> {
    // Always emitted, in this order.
    let always = [
        ("--model-path", path.to_string()),
        ("--model-name", name.to_string()),
        ("--http-port", self.http_port.to_string()),
        // Default 1
        ("--tensor-parallel-size", self.tensor_parallel_size.to_string()),
        // Default 0
        ("--base-gpu-id", self.base_gpu_id.to_string()),
        // Default 1
        ("--num-nodes", self.num_nodes.to_string()),
        // Default 0
        ("--node-rank", self.node_rank.to_string()),
    ];

    // Emitted only when the user supplied a value, in this order.
    let optional = [
        (
            "--model-config",
            self.model_config.as_ref().map(|p| p.display().to_string()),
        ),
        (
            "--leader-addr",
            self.leader_addr.as_ref().map(|addr| addr.to_string()),
        ),
        (
            "--extra-engine-args",
            self.extra_engine_args
                .as_ref()
                .map(|p| p.display().to_string()),
        ),
        (
            "--kv-overlap-score-weight",
            self.kv_overlap_score_weight.map(|w| w.to_string()),
        ),
        (
            "--kv-gpu-cache-usage-weight",
            self.kv_gpu_cache_usage_weight.map(|w| w.to_string()),
        ),
        (
            "--kv-waiting-requests-weight",
            self.kv_waiting_requests_weight.map(|w| w.to_string()),
        ),
    ];

    let mut out: Vec<String> = Vec::new();
    for (flag, value) in always {
        out.push(flag.to_string());
        out.push(value);
    }
    for (flag, maybe_value) in optional {
        if let Some(value) = maybe_value {
            out.push(flag.to_string());
            out.push(value);
        }
    }
    out.extend(self.last.clone());
    out
}
/// Load extra engine arguments from a JSON file
/// Returns a HashMap of parameter names to values
pub fn load_extra_engine_args(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod batch;
mod common;
pub mod endpoint;
pub mod http;
pub mod text;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
use std::{future::Future, pin::Pin};
use std::{io::Read, sync::Arc, time::Duration};
use anyhow::Context;
use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, local_model::LocalModel};
use dynamo_runtime::protocols::Endpoint as EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::{CancellationToken, DistributedRuntime};
use anyhow::Context as _;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_runtime::CancellationToken;
mod flags;
pub use flags::Flags;
mod input;
mod opt;
pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::{Input, Output};
pub use opt::Output;
mod subprocess;
const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);
/// Default size of a KV cache block. Override with --kv-cache-block-size
const DEFAULT_KV_CACHE_BLOCK_SIZE: usize = 16;
/// The kind of engine that requests are handed to.
pub enum EngineConfig {
/// Remote networked engines
Dynamic,
/// A Full service engine does its own tokenization and prompt formatting.
StaticFull {
engine: Arc<dyn StreamingEngine>,
model: Box<LocalModel>,
},
/// A core engine expects to be wrapped with pre/post processors that handle tokenization.
StaticCore {
engine: ExecutionContext,
model: Box<LocalModel>,
},
}
/// Reports whether the input is the `Input::Endpoint` variant.
fn is_in_dynamic(in_opt: &Input) -> bool {
    if let Input::Endpoint(_) = in_opt {
        return true;
    }
    false
}
/// Reports whether an output was given and it is `Output::Dynamic`.
fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
    out_opt
        .as_ref()
        .is_some_and(|out| matches!(out, Output::Dynamic))
}
pub async fn run(
runtime: dynamo_runtime::Runtime,
in_opt: Input,
out_opt: Option<Output>,
flags: Flags,
) -> anyhow::Result<()> {
if is_in_dynamic(&in_opt) && is_out_dynamic(&out_opt) {
anyhow::bail!("Cannot use endpoint for both in and out");
}
//
// Configure
//
let cancel_token = runtime.primary_token();
let maybe_path = flags
let mut builder = LocalModelBuilder::default();
builder
.model_path(
flags
.model_path_pos
.clone()
.or(flags.model_path_flag.clone());
let mut local_model: LocalModel = if is_out_dynamic(&out_opt) {
// If output is dynamic we are ingress and don't have a local model, but making an
// empty one cleans up the code.
Default::default()
} else {
// All other output types have a local model
match &maybe_path {
Some(model_path) => {
LocalModel::prepare(
model_path.to_str().context("Invalid UTF-8 in model path")?,
flags.model_config.as_deref(),
flags.model_name.clone(),
.or(flags.model_path_flag.clone()),
)
.await?
}
None => {
// echo_full engine doesn't need a path
match &flags.model_name {
Some(name) => LocalModel::with_name_only(name),
None => Default::default(),
}
}
}
.model_name(flags.model_name.clone())
.kv_cache_block_size(flags.kv_cache_block_size)
// Only set if user provides. Usually loaded from tokenizer_config.json
.context_length(flags.context_length)
.http_port(flags.http_port)
.router_config(flags.router_config())
.request_template(flags.request_template.clone());
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
if let Input::Endpoint(path) = &in_opt {
builder.endpoint_id(path.parse().with_context(|| path.clone())?);
};
// Only set if user provides. Usually loaded from tokenizer_config.json
if let Some(context_length) = flags.context_length {
local_model.set_context_length(context_length);
}
// Always set, there is no engine provided default
local_model.set_kv_cache_block_size(
flags
.kv_cache_block_size
.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE),
);
let local_model = builder.build().await?;
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
//
// Create an engine
//
let template = if let Some(path) = flags.request_template.as_ref() {
let template = RequestTemplate::load(path)?;
tracing::debug!("Using request template: {template:?}");
Some(template)
} else {
None
};
let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model));
print_cuda(&out_opt);
// We may need it later
let card = local_model.card().clone();
// Now that we know the output we're targeting, check if we expect it to work
flags.validate(&local_model, &out_opt)?;
let out_opt = out_opt.unwrap_or_else(|| {
let default_engine = if card.is_gguf() {
gguf_default()
} else {
safetensors_default()
};
tracing::info!(
"Using default engine: {default_engine}. Use out=<engine> to specify one of {}",
Output::available_engines().join(", ")
);
default_engine
});
print_cuda(&out_opt);
// Make an engine from the local_model, flags and output.
let (engine_config, extra) =
engine_for(runtime.primary_token(), out_opt, flags.clone(), local_model).await?;
// Create the engine matching `out`
let engine_config = match out_opt {
Output::Dynamic => {
// Sanity check - TODO probably make a general sanity check at start of method
if flags.context_length.is_some() {
anyhow::bail!("'--content-length' flag should only be used on the worker node, not on the ingress");
}
if flags.kv_cache_block_size.is_some() {
anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
}
EngineConfig::Dynamic
//
// Run in from an input
//
dynamo_llm::entrypoint::input::run_input(in_opt, runtime, engine_config).await?;
// Allow engines to ask main thread to wait on an extra future.
// We use this to stop the vllm and sglang sub-process
if let Some(extra) = extra {
extra.await;
}
Output::EchoFull => EngineConfig::StaticFull {
Ok(())
}
type ExtraFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Create the engine matching `out_opt`
/// Note validation happens in Flags::validate. In here assume everything is going to work.
async fn engine_for(
cancel_token: CancellationToken,
out_opt: Output,
flags: Flags,
local_model: LocalModel,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
match out_opt {
Output::Dynamic => Ok((EngineConfig::Dynamic(Box::new(local_model)), None)),
Output::EchoFull => Ok((
EngineConfig::StaticFull {
model: Box::new(local_model),
engine: dynamo_llm::engines::make_engine_full(),
},
Output::EchoCore => {
let card = local_model.card();
if !card.has_tokenizer() {
anyhow::bail!(
"out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
);
};
None,
)),
Output::EchoCore => Ok((
EngineConfig::StaticCore {
engine: dynamo_llm::engines::make_engine_core(),
model: Box::new(local_model),
}
}
},
None,
)),
#[cfg(feature = "mistralrs")]
Output::MistralRs => EngineConfig::StaticFull {
Output::MistralRs => Ok((
EngineConfig::StaticFull {
engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
model: Box::new(local_model),
},
Output::SgLang => {
if !local_model.path().is_dir() {
// TODO Does sglang support GGUF? Can we make it work?
anyhow::bail!("`--model-path should point at a HuggingFace repo checkout");
None,
)),
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => Ok((
EngineConfig::StaticCore {
engine: dynamo_engine_llamacpp::make_engine(cancel_token, &local_model).await?,
model: Box::new(local_model),
},
None,
)),
// For multi-node config. vllm uses `ray`, see guide
Output::Vllm => shell(subprocess::vllm::PY, cancel_token, local_model, flags, None).await,
// For multi-node config. trtlllm uses `mpi`, see guide
Output::Trtllm => {
shell(
subprocess::trtllm::PY,
cancel_token,
local_model,
flags,
None,
)
.await
}
// If `in=dyn` we want the sglang subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => internal_endpoint("sglang"),
};
let multi_node_conf = dynamo_llm::engines::MultiNodeConfig {
Output::SgLang => {
let multi_node_config = if flags.num_nodes > 1 {
Some(dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
leader_addr: flags.leader_addr.clone().unwrap_or_default(),
};
let (py_script, child) = match subprocess::start(
subprocess::sglang::PY,
&local_model,
&endpoint,
flags.clone(),
if flags.num_nodes <= 1 {
None
})
} else {
Some(multi_node_conf)
},
)
.await
{
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting sglang sub-process: {err}");
}
};
let cancel_token = cancel_token.clone();
// Sub-process cleanup
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
EngineConfig::Dynamic
}
Output::Vllm => {
if flags.base_gpu_id != 0 {
anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
// If `in=dyn` we want the vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => internal_endpoint("vllm"),
None
};
let (py_script, child) = match subprocess::start(
subprocess::vllm::PY,
&local_model,
&endpoint,
flags.clone(),
None, // multi-node config. vllm uses `ray`, see guide
shell(
subprocess::sglang::PY,
cancel_token,
local_model,
flags,
multi_node_config,
)
.await
{
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting vllm sub-process: {err}");
}
};
let cancel_token = cancel_token.clone();
// Sub-process cleanup
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
EngineConfig::Dynamic
}
Output::Trtllm => {
if flags.base_gpu_id != 0 {
anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
}
// If `in=dyn` we want the trtllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => internal_endpoint("trtllm"),
};
let (py_script, child) = match subprocess::start(
subprocess::trtllm::PY,
&local_model,
&endpoint,
flags.clone(),
None, // multi-node config. trtlllm uses `mpi`, see guide
)
.await
{
async fn shell(
py_script: &'static str,
cancel_token: CancellationToken,
local_model: LocalModel,
flags: Flags,
multi_node_config: Option<dynamo_llm::engines::MultiNodeConfig>,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
let (py_script, child) =
match subprocess::start(py_script, &local_model, flags.clone(), multi_node_config).await {
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting trtllm sub-process: {err}");
anyhow::bail!("Failed starting engine sub-process: {err}");
}
};
let cancel_token = cancel_token.clone();
// Sub-process cleanup
extra = Some(Box::pin(async move {
let extra: ExtraFuture = Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
EngineConfig::Dynamic
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
if !local_model.path().is_file() {
anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
}
let engine =
dynamo_engine_llamacpp::make_engine(cancel_token.clone(), &local_model).await?;
EngineConfig::StaticCore {
engine,
model: Box::new(local_model),
}
}
};
match in_opt {
Input::Http => {
crate::input::http::run(runtime.clone(), flags, engine_config, template).await?;
}
Input::Text => {
crate::input::text::run(runtime.clone(), flags, None, engine_config, template).await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
crate::input::text::run(
runtime.clone(),
flags,
Some(prompt),
engine_config,
template,
)
.await?;
}
Input::Batch(path) => {
crate::input::batch::run(runtime.clone(), flags, card, path, engine_config, template)
.await?;
}
Input::Endpoint(path) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
crate::input::endpoint::run(distributed_runtime, path, engine_config).await?;
}
}
// Allow engines to ask main thread to wait on an extra future.
// We use this to stop the vllm and sglang sub-process
if let Some(extra) = extra {
extra.await;
}
Ok(())
});
Ok((EngineConfig::Dynamic(Box::new(local_model)), Some(extra)))
}
/// Wait for cancel_token to be cancelled, then stop the child as gracefully as possible.
......@@ -341,21 +197,21 @@ async fn stopper(
tokio::select! {
exit = child.wait() => {
tracing::trace!("vllm sub-process graceful exit");
tracing::trace!("engine sub-process graceful exit");
match exit {
Ok(exit_status) if exit_status.success() => {}
Ok(exit_status) => {
// This is nearly always 15 (SIGTERM)
tracing::trace!("vllm sub-process non-0 exit: {exit_status}");
tracing::trace!("engine sub-process non-0 exit: {exit_status}");
}
Err(err) => {
tracing::warn!("vllm sub-process error getting exit status: {err}");
tracing::warn!("engine sub-process error getting exit status: {err}");
}
}
}
_ = tokio::time::sleep(CHILD_STOP_TIMEOUT) => {
// It didn't stop in time, kill it
child.kill().await.expect("Failed killing vllm subprocess");
child.kill().await.expect("Failed killing engine subprocess");
let _ = child.wait().await;
}
}
......@@ -400,6 +256,19 @@ fn print_cuda(output: &Output) {
#[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
fn print_cuda(_output: &Output) {}
fn default_engine_for(local_model: &LocalModel) -> Output {
let default_engine = if local_model.card().is_gguf() {
gguf_default()
} else {
safetensors_default()
};
tracing::info!(
"Using default engine: {default_engine}. Use out=<engine> to specify one of {}",
Output::available_engines().join(", ")
);
default_engine
}
fn gguf_default() -> Output {
#[cfg(feature = "llamacpp")]
{
......@@ -428,13 +297,3 @@ fn safetensors_default() -> Output {
Output::EchoFull
}
}
/// Build a randomly-namespaced endpoint for internal communication.
/// The namespace is a slugified fresh UUID rather than a hard-coded value because
/// several instances may be running on the same machine (GPUs 0-3 and 4-7).
fn internal_endpoint(engine: &str) -> EndpointId {
    let unique = uuid::Uuid::new_v4().to_string();
    let namespace = Slug::slugify(&unique).to_string();
    EndpointId {
        namespace,
        component: engine.to_owned(),
        name: String::from("generate"),
    }
}
......@@ -17,7 +17,8 @@ use std::env;
use clap::Parser;
use dynamo_run::{Input, Output};
use dynamo_llm::entrypoint::input::Input;
use dynamo_run::Output;
use dynamo_runtime::logging;
const HELP: &str = r#"
......@@ -127,5 +128,17 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
.chain(env::args().skip(non_flag_params)),
)?;
if is_in_dynamic(&in_opt) && is_out_dynamic(&out_opt) {
anyhow::bail!("Cannot use endpoint for both in and out");
}
dynamo_run::run(runtime, in_opt, out_opt, flags).await
}
/// Reports whether the input is the `Input::Endpoint` variant.
fn is_in_dynamic(in_opt: &Input) -> bool {
    if let Input::Endpoint(_) = in_opt {
        return true;
    }
    false
}
/// Reports whether an output was given and it is `Output::Dynamic`.
fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
    out_opt
        .as_ref()
        .is_some_and(|out| matches!(out, Output::Dynamic))
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{fmt, io::IsTerminal as _, path::PathBuf};
use dynamo_runtime::protocols::ENDPOINT_SCHEME;
const BATCH_PREFIX: &str = "batch:";
/// Where requests come from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Input {
    /// Run an OpenAI compatible HTTP server
    Http,
    /// Single prompt on stdin
    Stdin,
    /// Interactive chat
    Text,
    /// Pull requests from a namespace/component/endpoint path.
    Endpoint(String),
    /// Batch mode. Run all the prompts, write the outputs, exit.
    Batch(PathBuf),
}
/// Parse the value of an `in=` option.
///
/// Exact keywords are tried first, then the endpoint scheme prefix, then the
/// `batch:` prefix; anything else is an error.
impl TryFrom<&str> for Input {
    type Error = anyhow::Error;

    fn try_from(s: &str) -> anyhow::Result<Self> {
        let keyword = match s {
            "http" => Some(Input::Http),
            "text" => Some(Input::Text),
            "stdin" => Some(Input::Stdin),
            _ => None,
        };
        if let Some(input) = keyword {
            return Ok(input);
        }
        // An endpoint keeps the full string, scheme included.
        if s.starts_with(ENDPOINT_SCHEME) {
            return Ok(Input::Endpoint(s.to_string()));
        }
        // A batch input keeps only what follows the prefix.
        if let Some(batch_path) = s.strip_prefix(BATCH_PREFIX) {
            return Ok(Input::Batch(PathBuf::from(batch_path)));
        }
        Err(anyhow::anyhow!("Invalid in= option '{e}'", e = s))
    }
}
/// Render the input as its keyword, or as its stored path for the endpoint and batch variants.
impl fmt::Display for Input {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Input::Http => write!(f, "{}", "http"),
            Input::Text => write!(f, "{}", "text"),
            Input::Stdin => write!(f, "{}", "stdin"),
            Input::Endpoint(path) => write!(f, "{}", path.as_str()),
            // The batch path is written via `Path::display`, without the `batch:` prefix.
            Input::Batch(path) => write!(f, "{}", path.display()),
        }
    }
}
/// Default to interactive chat when stdin is a terminal, otherwise read a single prompt
/// from stdin.
impl Default for Input {
    fn default() -> Self {
        let interactive = std::io::stdin().is_terminal();
        if interactive {
            return Input::Text;
        }
        Input::Stdin
    }
}
use std::fmt;
pub enum Output {
/// Accept un-preprocessed requests, echo the prompt back as the response
......
......@@ -13,7 +13,6 @@ use tokio::io::AsyncBufReadExt;
use crate::flags::RouterMode;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang;
pub mod trtllm;
......@@ -24,8 +23,6 @@ pub async fn start(
py_script: &'static str,
// Model info
local_model: &LocalModel,
// Endpoint to connect the subprocess over etcd/nats
endpoint: &EndpointId,
// Command line flags for user overrides
flags: super::Flags,
// sglang multi-node config. vllm uses `ray` externally
......@@ -40,7 +37,7 @@ pub async fn start(
let mut args = vec![
script_path.to_string_lossy().to_string(),
"--endpoint".to_string(),
endpoint.as_url(),
local_model.endpoint_id().as_url(),
"--model-path".to_string(),
local_model.path().to_string_lossy().to_string(),
"--model-name".to_string(),
......
......@@ -6,7 +6,7 @@ use std::sync::Arc;
use clap::{Parser, Subcommand};
use dynamo_llm::discovery::{ModelManager, ModelWatcher};
use dynamo_llm::local_model::{LocalModel, ModelNetworkName};
use dynamo_llm::local_model::{LocalModelBuilder, ModelNetworkName};
use dynamo_llm::model_type::ModelType;
use dynamo_runtime::component::Endpoint;
use dynamo_runtime::pipeline::RouterMode;
......@@ -227,7 +227,10 @@ async fn add_model(
let endpoint = endpoint_from_name(distributed, &namespace, endpoint_name)?;
let mut model = LocalModel::with_name_only(&model_name);
let mut model = LocalModelBuilder::default()
.model_name(Some(model_name))
.build()
.await?;
model.attach(&endpoint, model_type).await?;
Ok(())
......
......@@ -96,7 +96,7 @@ pub unsafe extern "C" fn dynamo_llm_init(
match result {
Ok(_) => match KV_PUB.get_or_try_init(move || {
dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size as usize)
dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size)
}) {
Ok(_) => DynamoLlmResult::OK,
Err(e) => {
......@@ -139,7 +139,7 @@ fn dynamo_create_kv_publisher(
namespace: String,
component: String,
worker_id: i64,
kv_block_size: usize,
kv_block_size: u32,
) -> Result<KvEventPublisher, anyhow::Error> {
tracing::info!("Creating KV Publisher for model: {}", component);
match DRT
......@@ -158,7 +158,7 @@ fn kv_event_create_stored_block_from_parts(
block_hash: u64,
token_ids: *const u32,
num_tokens: usize,
kv_block_size: usize,
kv_block_size: u32,
_lora_id: u64,
) -> KvCacheStoredBlockData {
let tokens_hash = compute_block_hash_for_seq(
......@@ -174,7 +174,7 @@ static WARN_COUNT: AtomicU32 = AtomicU32::new(0);
fn kv_event_create_stored_from_parts(
kv_params: DynamoKvStoredEventParams,
kv_block_size: usize,
kv_block_size: u32,
) -> KvCacheEvent {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
......@@ -188,7 +188,7 @@ fn kv_event_create_stored_from_parts(
.offset(block_idx.try_into().unwrap())
};
if num_toks != kv_block_size {
if num_toks != (kv_block_size as usize) {
if WARN_COUNT
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
if c < 3 {
......
......@@ -1011,6 +1011,18 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "dialoguer"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de"
dependencies = [
"console",
"shell-words",
"tempfile",
"thiserror 1.0.69",
]
[[package]]
name = "digest"
version = "0.10.7"
......@@ -1123,6 +1135,7 @@ dependencies = [
"cudarc",
"derive-getters",
"derive_builder",
"dialoguer",
"dynamo-runtime",
"either",
"erased-serde",
......@@ -1131,6 +1144,7 @@ dependencies = [
"galil-seiferas",
"ggus",
"hf-hub",
"humantime",
"itertools 0.14.0",
"memmap2",
"minijinja",
......@@ -4224,6 +4238,12 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "shell-words"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
[[package]]
name = "shlex"
version = "1.3.0"
......
......@@ -9,6 +9,7 @@ use pyo3::types::{PyDict, PyList, PyString};
use pyo3::IntoPyObjectExt;
use pyo3::{exceptions::PyException, prelude::*};
use rs::pipeline::network::Ingress;
use std::path::PathBuf;
use std::{fmt::Display, sync::Arc};
use tokio::sync::Mutex;
......@@ -104,8 +105,8 @@ fn register_llm<'p>(
endpoint: Endpoint,
model_path: &str,
model_name: Option<&str>,
context_length: Option<usize>,
kv_cache_block_size: Option<usize>,
context_length: Option<u32>,
kv_cache_block_size: Option<u32>,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
......@@ -117,18 +118,14 @@ fn register_llm<'p>(
let inner_path = model_path.to_string();
let model_name = model_name.map(|n| n.to_string());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
builder
.model_path(Some(PathBuf::from(inner_path)))
.model_name(model_name)
.context_length(context_length)
.kv_cache_block_size(kv_cache_block_size);
// Download from HF, load the ModelDeploymentCard
let mut local_model =
llm_rs::local_model::LocalModel::prepare(&inner_path, None, model_name)
.await
.map_err(to_pyerr)?;
if let Some(context_length) = context_length {
local_model.set_context_length(context_length);
}
if let Some(kv_cache_block_size) = kv_cache_block_size {
local_model.set_kv_cache_block_size(kv_cache_block_size);
}
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
local_model
.attach(&endpoint.inner, model_type_obj)
......
......@@ -40,8 +40,11 @@ impl KvRouter {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner =
llm_rs::kv_router::KvRouter::new(component.inner.clone(), kv_block_size, None)
let inner = llm_rs::kv_router::KvRouter::new(
component.inner.clone(),
kv_block_size as u32,
None,
)
.await
.map_err(to_pyerr)?;
Ok(Self {
......@@ -73,7 +76,7 @@ pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) ->
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size);
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32);
Ok(hashes.into_iter().map(|h| h.0).collect())
}
......@@ -191,7 +194,7 @@ impl ZmqKvEventPublisher {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
config.worker_id,
config.kv_block_size,
config.kv_block_size as u32,
Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint,
topic: config.zmq_topic,
......@@ -232,7 +235,7 @@ impl ZmqKvEventListener {
zmq_topic,
tx,
shutdown_token.clone(),
kv_block_size,
kv_block_size as u32,
));
Ok(Self {
......@@ -293,7 +296,7 @@ impl KvEventPublisher {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
worker_id,
kv_block_size,
kv_block_size as u32,
None,
)
.map_err(to_pyerr)?;
......@@ -322,7 +325,7 @@ impl KvEventPublisher {
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
blocks: create_stored_blocks(
self.kv_block_size,
self.kv_block_size as u32,
&token_ids,
&num_block_tokens,
&block_hashes,
......@@ -446,7 +449,7 @@ impl KvIndexer {
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(),
kv_block_size,
kv_block_size as u32,
)
.into();
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
......@@ -478,7 +481,7 @@ impl KvIndexer {
}
fn block_size(&self) -> usize {
self.inner.block_size()
self.inner.block_size() as usize
}
fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {
......
......@@ -78,7 +78,7 @@ impl LlamacppEngine {
let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS);
let llama_ctx_params = if model_config.card().context_length > 0 {
let n_ctx = NonZeroU32::new(model_config.card().context_length as u32);
let n_ctx = NonZeroU32::new(model_config.card().context_length);
LlamaContextParams::default().with_n_ctx(n_ctx)
} else {
// Context length defaults to 512 currently
......
......@@ -128,7 +128,7 @@ impl MistralRsEngine {
.build(None)?
};
let mut max_seq_len = model.card().context_length;
let mut max_seq_len = model.card().context_length as usize;
if max_seq_len == 0 {
tracing::info!("context_length is 0. Probably error reading from model.");
max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
......
......@@ -53,6 +53,7 @@ either = { workspace = true }
etcd-client = { workspace = true }
futures = { workspace = true }
hf-hub = { workspace = true }
humantime = { workspace = true } # input/batch
rand = { workspace = true }
oneshot = { workspace = true }
prometheus = { workspace = true }
......@@ -80,6 +81,9 @@ offset-allocator = "0.2"
regex = "1"
rayon = "1"
# input/text
dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] }
# block_manager
nixl-sys = {git="https://github.com/ai-dynamo/nixl", rev = "a7c654d46a14cd5ce635cc8c02433d71df93dedf", optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
......
......@@ -1605,7 +1605,7 @@ mod tests {
use dynamo_runtime::logging::init as init_logging;
use nixl_sys::Agent as NixlAgent;
const BLOCK_SIZE: usize = 4;
const BLOCK_SIZE: u32 = 4;
const SALT_HASH: SaltHash = 12345;
// Helper to create a default reset block
......@@ -1666,7 +1666,7 @@ mod tests {
// Extend to fill capacity
assert!(block.add_tokens(Tokens::from(vec![4])).is_ok()); // 1, 2, 3, 4
assert_eq!(block.len(), BLOCK_SIZE);
assert_eq!(block.len(), BLOCK_SIZE as usize);
// Append when full (should fail)
assert!(block.add_token(5).is_err(), "Append on full Partial block");
......@@ -1690,7 +1690,7 @@ mod tests {
// Fill block again for commit
assert!(block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).is_ok());
assert_eq!(block.len(), BLOCK_SIZE);
assert_eq!(block.len(), BLOCK_SIZE as usize);
// --- Partial -> Complete (via commit) --- //
assert!(block.commit().is_ok());
......
......@@ -43,7 +43,7 @@ impl BlockState {
return Err(BlockStateInvalid("Block is not reset".to_string()));
}
let block = PartialTokenBlock::create_sequence_root(page_size, salt_hash);
let block = PartialTokenBlock::create_sequence_root(page_size as u32, salt_hash);
*self = BlockState::Partial(PartialState::new(block));
Ok(())
}
......
......@@ -648,7 +648,7 @@ pub(crate) mod tests {
/// Each block is initialized to the Complete state and then Registered.
pub fn create_blocks(
tokens: Tokens,
block_size: usize,
block_size: u32,
async_runtime: Handle,
) -> Vec<Block<NullDeviceStorage, TestMetadata>> {
let (token_blocks, _partial_token_block) =
......@@ -691,7 +691,7 @@ pub(crate) mod tests {
pub fn acquire_blocks(
tokens: Tokens,
block_size: usize,
block_size: u32,
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>,
async_runtime: Handle,
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) {
......@@ -749,7 +749,7 @@ pub(crate) mod tests {
let async_runtime = tokio::runtime::Runtime::new().unwrap();
const PAGE_SIZE: usize = 2;
const PAGE_SIZE: u32 = 2;
let mut pool = create_block_pool(10);
assert_eq!(pool.total_blocks(), 10);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment