"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "aa16ccf545f30d637747c88de41a8fcdaa65ab78"
Unverified Commit 92f06b0e authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(dynamo-run): Refactor to library (#1687)

Move much of what was in the `dynamo-run` crate into `dynamo-llm` so that everyone can use it.

Example usage:

1. Create a `LocalModel`:

```
    let local_model = LocalModelBuilder::default()
	.model_path("Qwen/Qwen3-0.6B")
	.http_port(8080)
	.build().await?;
```

2. Make an engine:

```
    let engine_config = EngineConfig::StaticFull {
	engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
	model: Box::new(local_model),
    };
```

3. Connect it to an input and run it:

```
    dynamo_llm::entrypoint::input::run_input(Input::Http, runtime, engine_config).await?;
```

For https://github.com/ai-dynamo/dynamo/issues/1647

Code Rabbit summary, thanks:
  * Introduced a flexible builder pattern for local model configuration, allowing advanced customization and easier initialization.
  * Added new input modes and unified input handling, supporting interactive chat, HTTP server, batch file, and distributed endpoint modes.
  * Centralized engine configuration and routing, enabling more extensible and maintainable engine management.
  * Simplified and modularized the codebase by moving input and engine logic into dedicated modules.
  * Replaced direct model construction with an asynchronous builder for improved clarity and extensibility.
  * Streamlined configuration and validation for flags and router settings.
  * Added validation to prevent incompatible input and output combinations in endpoint and dynamic modes.
parent 3b62692f
...@@ -1619,6 +1619,7 @@ dependencies = [ ...@@ -1619,6 +1619,7 @@ dependencies = [
"cudarc 0.16.2", "cudarc 0.16.2",
"derive-getters", "derive-getters",
"derive_builder", "derive_builder",
"dialoguer",
"dynamo-runtime", "dynamo-runtime",
"either", "either",
"erased-serde", "erased-serde",
...@@ -1627,6 +1628,7 @@ dependencies = [ ...@@ -1627,6 +1628,7 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"ggus", "ggus",
"hf-hub", "hf-hub",
"humantime",
"insta", "insta",
"itertools 0.14.0", "itertools 0.14.0",
"lazy_static", "lazy_static",
...@@ -1677,14 +1679,12 @@ dependencies = [ ...@@ -1677,14 +1679,12 @@ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"clap", "clap",
"dialoguer",
"dynamo-engine-llamacpp", "dynamo-engine-llamacpp",
"dynamo-engine-mistralrs", "dynamo-engine-mistralrs",
"dynamo-llm", "dynamo-llm",
"dynamo-runtime", "dynamo-runtime",
"futures", "futures",
"futures-util", "futures-util",
"humantime",
"libc", "libc",
"regex", "regex",
"serde", "serde",
......
...@@ -47,7 +47,7 @@ struct Args { ...@@ -47,7 +47,7 @@ struct Args {
/// Block size for the router /// Block size for the router
#[arg(long)] #[arg(long)]
block_size: usize, block_size: u32,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
...@@ -88,7 +88,7 @@ impl WorkerSelector for CustomWorkerSelector { ...@@ -88,7 +88,7 @@ impl WorkerSelector for CustomWorkerSelector {
&self, &self,
workers: &ProcessedEndpoints, workers: &ProcessedEndpoints,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: usize, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
// customize logic here // customize logic here
// F12 into [DefaultWorkerSelector] to see the original logic // F12 into [DefaultWorkerSelector] to see the original logic
......
...@@ -34,7 +34,6 @@ anyhow = { workspace = true } ...@@ -34,7 +34,6 @@ anyhow = { workspace = true }
async-stream = { workspace = true } async-stream = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
humantime = { workspace = true }
libc = { workspace = true } libc = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
...@@ -47,7 +46,6 @@ uuid = { workspace = true } ...@@ -47,7 +46,6 @@ uuid = { workspace = true }
async-openai = { workspace = true } async-openai = { workspace = true }
clap = { version = "4.5", features = ["derive", "env"] } clap = { version = "4.5", features = ["derive", "env"] }
dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] }
futures-util = { version = "0.3" } futures-util = { version = "0.3" }
regex = "1" regex = "1"
......
...@@ -17,9 +17,13 @@ use std::collections::HashMap; ...@@ -17,9 +17,13 @@ use std::collections::HashMap;
use std::path::PathBuf; use std::path::PathBuf;
use clap::ValueEnum; use clap::ValueEnum;
use dynamo_llm::entrypoint::RouterConfig;
use dynamo_llm::kv_router::KvRouterConfig; use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode; use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode;
use crate::Output;
/// Required options depend on the in and out choices /// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)] #[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
...@@ -125,11 +129,11 @@ pub struct Flags { ...@@ -125,11 +129,11 @@ pub struct Flags {
/// context length (e.g. Llama 4). /// context length (e.g. Llama 4).
/// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json. /// Defaults to the model's max, which is usually model_max_length in tokenizer_config.json.
#[arg(long)] #[arg(long)]
pub context_length: Option<usize>, pub context_length: Option<u32>,
/// KV cache block size (vllm only) /// KV cache block size (vllm only)
#[arg(long)] #[arg(long)]
pub kv_cache_block_size: Option<usize>, pub kv_cache_block_size: Option<u32>,
/// Additional engine-specific arguments from a JSON file. /// Additional engine-specific arguments from a JSON file.
/// Contains a mapping of parameter names to values. /// Contains a mapping of parameter names to values.
...@@ -154,68 +158,65 @@ pub struct Flags { ...@@ -154,68 +158,65 @@ pub struct Flags {
} }
impl Flags { impl Flags {
/// Get KV router configuration /// For each Output variant, check if it would be able to run.
pub fn kv_router_config(&self) -> KvRouterConfig { /// This takes validation out of the main engine creation path.
/// Check that the chosen output engine is compatible with these flags and `local_model`.
///
/// # Errors
///
/// Returns an error naming the first flag or model path that the selected `out_opt`
/// cannot work with.
pub fn validate(&self, local_model: &LocalModel, out_opt: &Output) -> anyhow::Result<()> {
    match out_opt {
        Output::Dynamic => {
            // These two flags are rejected for the dynamic (ingress) output.
            if self.context_length.is_some() {
                anyhow::bail!("'--context-length' flag should only be used on the worker node, not on the ingress");
            }
            if self.kv_cache_block_size.is_some() {
                anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
            }
        }
        Output::EchoFull => {}
        Output::EchoCore => {
            if !local_model.card().has_tokenizer() {
                anyhow::bail!(
                    "out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
                );
            }
        }
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => {}
        Output::SgLang => {
            if !local_model.path().is_dir() {
                // TODO GGUF support for sglang: https://github.com/ai-dynamo/dynamo/issues/572
                anyhow::bail!("`--model-path` should point at a HuggingFace repo checkout");
            }
        }
        Output::Vllm => {
            if self.base_gpu_id != 0 {
                anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
            }
        }
        Output::Trtllm => {
            if self.base_gpu_id != 0 {
                anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
            }
        }
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => {
            if !local_model.path().is_file() {
                anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
            }
        }
    }
    Ok(())
}
pub fn router_config(&self) -> RouterConfig {
RouterConfig::new(
self.router_mode.into(),
KvRouterConfig::new( KvRouterConfig::new(
self.kv_overlap_score_weight, self.kv_overlap_score_weight,
self.kv_gpu_cache_usage_weight, self.kv_gpu_cache_usage_weight,
self.kv_waiting_requests_weight, self.kv_waiting_requests_weight,
),
) )
} }
/// Rebuild a command line from these flags.
///
/// The model path and name passed in, plus every flag that carries a default, are always
/// emitted. Optional flags are emitted only when set. The contents of `self.last` are
/// appended at the end.
///
/// Used to pass arguments to python engines via `pystr` and `pytok`.
pub fn as_vec(&self, path: &str, name: &str) -> Vec<String> {
    // Flags that are always present, in emission order.
    let always = [
        ("--model-path", path.to_string()),
        ("--model-name", name.to_string()),
        ("--http-port", self.http_port.to_string()),
        // Default 1
        ("--tensor-parallel-size", self.tensor_parallel_size.to_string()),
        // Default 0
        ("--base-gpu-id", self.base_gpu_id.to_string()),
        // Default 1
        ("--num-nodes", self.num_nodes.to_string()),
        // Default 0
        ("--node-rank", self.node_rank.to_string()),
    ];

    // Flags that only appear when a value was supplied, in emission order.
    let when_set = [
        (
            "--model-config",
            self.model_config
                .as_ref()
                .map(|config_path| config_path.display().to_string()),
        ),
        (
            "--leader-addr",
            self.leader_addr.as_ref().map(|addr| addr.to_string()),
        ),
        (
            "--extra-engine-args",
            self.extra_engine_args
                .as_ref()
                .map(|args_path| args_path.display().to_string()),
        ),
        (
            "--kv-overlap-score-weight",
            self.kv_overlap_score_weight.map(|w| w.to_string()),
        ),
        (
            "--kv-gpu-cache-usage-weight",
            self.kv_gpu_cache_usage_weight.map(|w| w.to_string()),
        ),
        (
            "--kv-waiting-requests-weight",
            self.kv_waiting_requests_weight.map(|w| w.to_string()),
        ),
    ];

    let mut args = Vec::new();
    for (flag, value) in always {
        args.push(flag.to_string());
        args.push(value);
    }
    for (flag, maybe_value) in when_set {
        if let Some(value) = maybe_value {
            args.push(flag.to_string());
            args.push(value);
        }
    }
    args.extend(self.last.clone());
    args
}
/// Load extra engine arguments from a JSON file /// Load extra engine arguments from a JSON file
/// Returns a HashMap of parameter names to values /// Returns a HashMap of parameter names to values
pub fn load_extra_engine_args( pub fn load_extra_engine_args(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod batch;
mod common;
pub mod endpoint;
pub mod http;
pub mod text;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::time::Duration;
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use std::{io::Read, sync::Arc, time::Duration};
use anyhow::Context; use anyhow::Context as _;
use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, local_model::LocalModel}; use dynamo_llm::entrypoint::input::Input;
use dynamo_runtime::protocols::Endpoint as EndpointId; use dynamo_llm::entrypoint::EngineConfig;
use dynamo_runtime::slug::Slug; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_runtime::{CancellationToken, DistributedRuntime}; use dynamo_runtime::CancellationToken;
mod flags; mod flags;
pub use flags::Flags; pub use flags::Flags;
mod input;
mod opt; mod opt;
pub use dynamo_llm::request_template::RequestTemplate; pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::{Input, Output}; pub use opt::Output;
mod subprocess; mod subprocess;
const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2); const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);
/// Default size of a KV cache block. Override with --kv-cache-block-size
const DEFAULT_KV_CACHE_BLOCK_SIZE: usize = 16;
/// The engine to run, grouped by how much of the request pipeline it handles itself.
pub enum EngineConfig {
/// Remote networked engines
Dynamic,
/// A Full service engine does its own tokenization and prompt formatting.
StaticFull {
engine: Arc<dyn StreamingEngine>,
model: Box<LocalModel>,
},
/// A core engine expects to be wrapped with pre/post processors that handle tokenization.
StaticCore {
engine: ExecutionContext,
model: Box<LocalModel>,
},
}
/// Returns `true` when the input is an `Input::Endpoint`.
fn is_in_dynamic(in_opt: &Input) -> bool {
    if let Input::Endpoint(_) = in_opt {
        return true;
    }
    false
}
/// Returns `true` only when an output was given and it is `Output::Dynamic`.
fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
    out_opt
        .as_ref()
        .is_some_and(|out| matches!(out, Output::Dynamic))
}
pub async fn run( pub async fn run(
runtime: dynamo_runtime::Runtime, runtime: dynamo_runtime::Runtime,
in_opt: Input, in_opt: Input,
out_opt: Option<Output>, out_opt: Option<Output>,
flags: Flags, flags: Flags,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
if is_in_dynamic(&in_opt) && is_out_dynamic(&out_opt) { //
anyhow::bail!("Cannot use endpoint for both in and out"); // Configure
} //
let cancel_token = runtime.primary_token(); let mut builder = LocalModelBuilder::default();
let maybe_path = flags builder
.model_path(
flags
.model_path_pos .model_path_pos
.clone() .clone()
.or(flags.model_path_flag.clone()); .or(flags.model_path_flag.clone()),
let mut local_model: LocalModel = if is_out_dynamic(&out_opt) {
// If output is dynamic we are ingress and don't have a local model, but making an
// empty one cleans up the code.
Default::default()
} else {
// All other output types have a local model
match &maybe_path {
Some(model_path) => {
LocalModel::prepare(
model_path.to_str().context("Invalid UTF-8 in model path")?,
flags.model_config.as_deref(),
flags.model_name.clone(),
) )
.await? .model_name(flags.model_name.clone())
} .kv_cache_block_size(flags.kv_cache_block_size)
None => { // Only set if user provides. Usually loaded from tokenizer_config.json
// echo_full engine doesn't need a path .context_length(flags.context_length)
match &flags.model_name { .http_port(flags.http_port)
Some(name) => LocalModel::with_name_only(name), .router_config(flags.router_config())
None => Default::default(), .request_template(flags.request_template.clone());
}
} // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
} // If not, then the endpoint isn't exposed so we let LocalModel invent one.
if let Input::Endpoint(path) = &in_opt {
builder.endpoint_id(path.parse().with_context(|| path.clone())?);
}; };
// Only set if user provides. Usually loaded from tokenizer_config.json let local_model = builder.build().await?;
if let Some(context_length) = flags.context_length {
local_model.set_context_length(context_length);
}
// Always set, there is no engine provided default
local_model.set_kv_cache_block_size(
flags
.kv_cache_block_size
.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE),
);
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process //
// Create an engine
//
let template = if let Some(path) = flags.request_template.as_ref() { let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model));
let template = RequestTemplate::load(path)?; print_cuda(&out_opt);
tracing::debug!("Using request template: {template:?}");
Some(template)
} else {
None
};
// We may need it later // Now that we know the output we're targeting, check if we expect it to work
let card = local_model.card().clone(); flags.validate(&local_model, &out_opt)?;
let out_opt = out_opt.unwrap_or_else(|| { // Make an engine from the local_model, flags and output.
let default_engine = if card.is_gguf() { let (engine_config, extra) =
gguf_default() engine_for(runtime.primary_token(), out_opt, flags.clone(), local_model).await?;
} else {
safetensors_default()
};
tracing::info!(
"Using default engine: {default_engine}. Use out=<engine> to specify one of {}",
Output::available_engines().join(", ")
);
default_engine
});
print_cuda(&out_opt);
// Create the engine matching `out` //
let engine_config = match out_opt { // Run in from an input
Output::Dynamic => { //
// Sanity check - TODO probably make a general sanity check at start of method
if flags.context_length.is_some() { dynamo_llm::entrypoint::input::run_input(in_opt, runtime, engine_config).await?;
anyhow::bail!("'--content-length' flag should only be used on the worker node, not on the ingress");
} // Allow engines to ask main thread to wait on an extra future.
if flags.kv_cache_block_size.is_some() { // We use this to stop the vllm and sglang sub-process
anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress"); if let Some(extra) = extra {
} extra.await;
EngineConfig::Dynamic
} }
Output::EchoFull => EngineConfig::StaticFull {
Ok(())
}
type ExtraFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// Create the engine matching `out_opt`
/// Note validation happens in Flags::validate. In here assume everything is going to work.
async fn engine_for(
cancel_token: CancellationToken,
out_opt: Output,
flags: Flags,
local_model: LocalModel,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
match out_opt {
Output::Dynamic => Ok((EngineConfig::Dynamic(Box::new(local_model)), None)),
Output::EchoFull => Ok((
EngineConfig::StaticFull {
model: Box::new(local_model), model: Box::new(local_model),
engine: dynamo_llm::engines::make_engine_full(), engine: dynamo_llm::engines::make_engine_full(),
}, },
Output::EchoCore => { None,
let card = local_model.card(); )),
if !card.has_tokenizer() { Output::EchoCore => Ok((
anyhow::bail!(
"out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
);
};
EngineConfig::StaticCore { EngineConfig::StaticCore {
engine: dynamo_llm::engines::make_engine_core(), engine: dynamo_llm::engines::make_engine_core(),
model: Box::new(local_model), model: Box::new(local_model),
} },
} None,
)),
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
Output::MistralRs => EngineConfig::StaticFull { Output::MistralRs => Ok((
EngineConfig::StaticFull {
engine: dynamo_engine_mistralrs::make_engine(&local_model).await?, engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
model: Box::new(local_model), model: Box::new(local_model),
}, },
Output::SgLang => { None,
if !local_model.path().is_dir() { )),
// TODO Does sglang support GGUF? Can we make it work? #[cfg(feature = "llamacpp")]
anyhow::bail!("`--model-path should point at a HuggingFace repo checkout"); Output::LlamaCpp => Ok((
EngineConfig::StaticCore {
engine: dynamo_engine_llamacpp::make_engine(cancel_token, &local_model).await?,
model: Box::new(local_model),
},
None,
)),
// For multi-node config. vllm uses `ray`, see guide
Output::Vllm => shell(subprocess::vllm::PY, cancel_token, local_model, flags, None).await,
// For multi-node config. trtllm uses `mpi`, see guide
Output::Trtllm => {
shell(
subprocess::trtllm::PY,
cancel_token,
local_model,
flags,
None,
)
.await
} }
Output::SgLang => {
// If `in=dyn` we want the sglang subprocess to listen on that endpoint. let multi_node_config = if flags.num_nodes > 1 {
// If not, then the endpoint isn't exposed so we invent an internal one. Some(dynamo_llm::engines::MultiNodeConfig {
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => internal_endpoint("sglang"),
};
let multi_node_conf = dynamo_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes, num_nodes: flags.num_nodes,
node_rank: flags.node_rank, node_rank: flags.node_rank,
leader_addr: flags.leader_addr.clone().unwrap_or_default(), leader_addr: flags.leader_addr.clone().unwrap_or_default(),
}; })
let (py_script, child) = match subprocess::start(
subprocess::sglang::PY,
&local_model,
&endpoint,
flags.clone(),
if flags.num_nodes <= 1 {
None
} else { } else {
Some(multi_node_conf) None
},
)
.await
{
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting sglang sub-process: {err}");
}
};
let cancel_token = cancel_token.clone();
// Sub-process cleanup
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
EngineConfig::Dynamic
}
Output::Vllm => {
if flags.base_gpu_id != 0 {
anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
// If `in=dyn` we want the vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we invent an internal one.
let endpoint = match &in_opt {
Input::Endpoint(path) => path.parse()?,
_ => internal_endpoint("vllm"),
}; };
shell(
let (py_script, child) = match subprocess::start( subprocess::sglang::PY,
subprocess::vllm::PY, cancel_token,
&local_model, local_model,
&endpoint, flags,
flags.clone(), multi_node_config,
None, // multi-node config. vllm uses `ray`, see guide
) )
.await .await
{
Ok(x) => x,
Err(err) => {
anyhow::bail!("Failed starting vllm sub-process: {err}");
}
};
let cancel_token = cancel_token.clone();
// Sub-process cleanup
extra = Some(Box::pin(async move {
stopper(cancel_token, child, py_script).await;
}));
EngineConfig::Dynamic
} }
Output::Trtllm => {
if flags.base_gpu_id != 0 {
anyhow::bail!("TRTLLM does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
} }
}
// If `in=dyn` we want the trtllm subprocess to listen on that endpoint. async fn shell(
// If not, then the endpoint isn't exposed so we invent an internal one. py_script: &'static str,
let endpoint = match &in_opt { cancel_token: CancellationToken,
Input::Endpoint(path) => path.parse()?, local_model: LocalModel,
_ => internal_endpoint("trtllm"), flags: Flags,
}; multi_node_config: Option<dynamo_llm::engines::MultiNodeConfig>,
) -> anyhow::Result<(EngineConfig, Option<ExtraFuture>)> {
let (py_script, child) = match subprocess::start( let (py_script, child) =
subprocess::trtllm::PY, match subprocess::start(py_script, &local_model, flags.clone(), multi_node_config).await {
&local_model,
&endpoint,
flags.clone(),
None, // multi-node config. trtlllm uses `mpi`, see guide
)
.await
{
Ok(x) => x, Ok(x) => x,
Err(err) => { Err(err) => {
anyhow::bail!("Failed starting trtllm sub-process: {err}"); anyhow::bail!("Failed starting engine sub-process: {err}");
} }
}; };
let cancel_token = cancel_token.clone();
// Sub-process cleanup // Sub-process cleanup
extra = Some(Box::pin(async move { let extra: ExtraFuture = Box::pin(async move {
stopper(cancel_token, child, py_script).await; stopper(cancel_token, child, py_script).await;
})); });
EngineConfig::Dynamic Ok((EngineConfig::Dynamic(Box::new(local_model)), Some(extra)))
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
if !local_model.path().is_file() {
anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
}
let engine =
dynamo_engine_llamacpp::make_engine(cancel_token.clone(), &local_model).await?;
EngineConfig::StaticCore {
engine,
model: Box::new(local_model),
}
}
};
match in_opt {
Input::Http => {
crate::input::http::run(runtime.clone(), flags, engine_config, template).await?;
}
Input::Text => {
crate::input::text::run(runtime.clone(), flags, None, engine_config, template).await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
crate::input::text::run(
runtime.clone(),
flags,
Some(prompt),
engine_config,
template,
)
.await?;
}
Input::Batch(path) => {
crate::input::batch::run(runtime.clone(), flags, card, path, engine_config, template)
.await?;
}
Input::Endpoint(path) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
crate::input::endpoint::run(distributed_runtime, path, engine_config).await?;
}
}
// Allow engines to ask main thread to wait on an extra future.
// We use this to stop the vllm and sglang sub-process
if let Some(extra) = extra {
extra.await;
}
Ok(())
} }
/// Wait for cancel_token to be cancelled, then stop the child as gracefully as possible. /// Wait for cancel_token to be cancelled, then stop the child as gracefully as possible.
...@@ -341,21 +197,21 @@ async fn stopper( ...@@ -341,21 +197,21 @@ async fn stopper(
tokio::select! { tokio::select! {
exit = child.wait() => { exit = child.wait() => {
tracing::trace!("vllm sub-process graceful exit"); tracing::trace!("engine sub-process graceful exit");
match exit { match exit {
Ok(exit_status) if exit_status.success() => {} Ok(exit_status) if exit_status.success() => {}
Ok(exit_status) => { Ok(exit_status) => {
// This is nearly always 15 (SIGTERM) // This is nearly always 15 (SIGTERM)
tracing::trace!("vllm sub-process non-0 exit: {exit_status}"); tracing::trace!("engine sub-process non-0 exit: {exit_status}");
} }
Err(err) => { Err(err) => {
tracing::warn!("vllm sub-process error getting exit status: {err}"); tracing::warn!("engine sub-process error getting exit status: {err}");
} }
} }
} }
_ = tokio::time::sleep(CHILD_STOP_TIMEOUT) => { _ = tokio::time::sleep(CHILD_STOP_TIMEOUT) => {
// It didn't stop in time, kill it // It didn't stop in time, kill it
child.kill().await.expect("Failed killing vllm subprocess"); child.kill().await.expect("Failed killing engine subprocess");
let _ = child.wait().await; let _ = child.wait().await;
} }
} }
...@@ -400,6 +256,19 @@ fn print_cuda(output: &Output) { ...@@ -400,6 +256,19 @@ fn print_cuda(output: &Output) {
#[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))] #[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
fn print_cuda(_output: &Output) {} fn print_cuda(_output: &Output) {}
/// Choose the engine for a model when none was requested, and log the choice.
///
/// GGUF models get `gguf_default()`, everything else gets `safetensors_default()`.
fn default_engine_for(local_model: &LocalModel) -> Output {
    let is_gguf = local_model.card().is_gguf();
    let default_engine = match is_gguf {
        true => gguf_default(),
        false => safetensors_default(),
    };
    tracing::info!(
        "Using default engine: {default_engine}. Use out=<engine> to specify one of {}",
        Output::available_engines().join(", ")
    );
    default_engine
}
fn gguf_default() -> Output { fn gguf_default() -> Output {
#[cfg(feature = "llamacpp")] #[cfg(feature = "llamacpp")]
{ {
...@@ -428,13 +297,3 @@ fn safetensors_default() -> Output { ...@@ -428,13 +297,3 @@ fn safetensors_default() -> Output {
Output::EchoFull Output::EchoFull
} }
} }
/// Build an endpoint id for internal communication with a freshly generated namespace.
///
/// The namespace is random rather than hard coded because several instances may be
/// running on the same machine (GPUs 0-3 and 4-7).
fn internal_endpoint(engine: &str) -> EndpointId {
    let unique = uuid::Uuid::new_v4().to_string();
    let namespace = Slug::slugify(&unique).to_string();
    EndpointId {
        namespace,
        component: engine.to_owned(),
        name: String::from("generate"),
    }
}
...@@ -17,7 +17,8 @@ use std::env; ...@@ -17,7 +17,8 @@ use std::env;
use clap::Parser; use clap::Parser;
use dynamo_run::{Input, Output}; use dynamo_llm::entrypoint::input::Input;
use dynamo_run::Output;
use dynamo_runtime::logging; use dynamo_runtime::logging;
const HELP: &str = r#" const HELP: &str = r#"
...@@ -127,5 +128,17 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> { ...@@ -127,5 +128,17 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
.chain(env::args().skip(non_flag_params)), .chain(env::args().skip(non_flag_params)),
)?; )?;
if is_in_dynamic(&in_opt) && is_out_dynamic(&out_opt) {
anyhow::bail!("Cannot use endpoint for both in and out");
}
dynamo_run::run(runtime, in_opt, out_opt, flags).await dynamo_run::run(runtime, in_opt, out_opt, flags).await
} }
/// Returns `true` when the input is an `Input::Endpoint`.
fn is_in_dynamic(in_opt: &Input) -> bool {
    if let Input::Endpoint(_) = in_opt {
        return true;
    }
    false
}
/// Returns `true` only when an output was given and it is `Output::Dynamic`.
fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
    out_opt
        .as_ref()
        .is_some_and(|out| matches!(out, Output::Dynamic))
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{fmt, io::IsTerminal as _, path::PathBuf};
use dynamo_runtime::protocols::ENDPOINT_SCHEME; use dynamo_runtime::protocols::ENDPOINT_SCHEME;
use std::fmt;
const BATCH_PREFIX: &str = "batch:";
/// Where requests come from, as selected by the `in=` option.
#[derive(PartialEq)]
pub enum Input {
/// Run an OpenAI compatible HTTP server
Http,
/// Single prompt on stdin
Stdin,
/// Interactive chat
Text,
/// Pull requests from a namespace/component/endpoint path.
Endpoint(String),
/// Batch mode. Run all the prompts, write the outputs, exit.
Batch(PathBuf),
}
/// Parse the value of an `in=` option.
///
/// Accepts the literals `http`, `text` and `stdin`, anything starting with
/// `ENDPOINT_SCHEME` (kept whole as the endpoint path), or `BATCH_PREFIX` followed by a
/// file path. Any other value is an error.
impl TryFrom<&str> for Input {
    type Error = anyhow::Error;

    fn try_from(s: &str) -> anyhow::Result<Self> {
        let parsed = match s {
            "http" => Input::Http,
            "text" => Input::Text,
            "stdin" => Input::Stdin,
            // The endpoint check runs before the batch check, as in the original ordering.
            _ if s.starts_with(ENDPOINT_SCHEME) => Input::Endpoint(s.to_owned()),
            e => match e.strip_prefix(BATCH_PREFIX) {
                Some(batch_file) => Input::Batch(PathBuf::from(batch_file)),
                None => return Err(anyhow::anyhow!("Invalid in= option '{e}'")),
            },
        };
        Ok(parsed)
    }
}
/// Render the input as its keyword, or as the carried endpoint / batch path.
impl fmt::Display for Input {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Input::Http => f.write_str("http"),
            Input::Text => f.write_str("text"),
            Input::Stdin => f.write_str("stdin"),
            Input::Endpoint(endpoint_path) => f.write_str(endpoint_path),
            // Only the path is written; the batch prefix is not included.
            Input::Batch(batch_file) => write!(f, "{}", batch_file.display()),
        }
    }
}
/// Default input: interactive text when stdin is a terminal, otherwise a prompt read from stdin.
impl Default for Input {
    fn default() -> Self {
        let interactive = std::io::stdin().is_terminal();
        match interactive {
            true => Input::Text,
            false => Input::Stdin,
        }
    }
}
pub enum Output { pub enum Output {
/// Accept un-preprocessed requests, echo the prompt back as the response /// Accept un-preprocessed requests, echo the prompt back as the response
......
...@@ -13,7 +13,6 @@ use tokio::io::AsyncBufReadExt; ...@@ -13,7 +13,6 @@ use tokio::io::AsyncBufReadExt;
use crate::flags::RouterMode; use crate::flags::RouterMode;
use dynamo_llm::engines::MultiNodeConfig; use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::local_model::LocalModel; use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang; pub mod sglang;
pub mod trtllm; pub mod trtllm;
...@@ -24,8 +23,6 @@ pub async fn start( ...@@ -24,8 +23,6 @@ pub async fn start(
py_script: &'static str, py_script: &'static str,
// Model info // Model info
local_model: &LocalModel, local_model: &LocalModel,
// Endpoint to connect the subprocess over etcd/nats
endpoint: &EndpointId,
// Command line flags for user overrides // Command line flags for user overrides
flags: super::Flags, flags: super::Flags,
// sglang multi-node config. vllm uses `ray` externally // sglang multi-node config. vllm uses `ray` externally
...@@ -40,7 +37,7 @@ pub async fn start( ...@@ -40,7 +37,7 @@ pub async fn start(
let mut args = vec![ let mut args = vec![
script_path.to_string_lossy().to_string(), script_path.to_string_lossy().to_string(),
"--endpoint".to_string(), "--endpoint".to_string(),
endpoint.as_url(), local_model.endpoint_id().as_url(),
"--model-path".to_string(), "--model-path".to_string(),
local_model.path().to_string_lossy().to_string(), local_model.path().to_string_lossy().to_string(),
"--model-name".to_string(), "--model-name".to_string(),
......
...@@ -6,7 +6,7 @@ use std::sync::Arc; ...@@ -6,7 +6,7 @@ use std::sync::Arc;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use dynamo_llm::discovery::{ModelManager, ModelWatcher}; use dynamo_llm::discovery::{ModelManager, ModelWatcher};
use dynamo_llm::local_model::{LocalModel, ModelNetworkName}; use dynamo_llm::local_model::{LocalModelBuilder, ModelNetworkName};
use dynamo_llm::model_type::ModelType; use dynamo_llm::model_type::ModelType;
use dynamo_runtime::component::Endpoint; use dynamo_runtime::component::Endpoint;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
...@@ -227,7 +227,10 @@ async fn add_model( ...@@ -227,7 +227,10 @@ async fn add_model(
let endpoint = endpoint_from_name(distributed, &namespace, endpoint_name)?; let endpoint = endpoint_from_name(distributed, &namespace, endpoint_name)?;
let mut model = LocalModel::with_name_only(&model_name); let mut model = LocalModelBuilder::default()
.model_name(Some(model_name))
.build()
.await?;
model.attach(&endpoint, model_type).await?; model.attach(&endpoint, model_type).await?;
Ok(()) Ok(())
......
...@@ -96,7 +96,7 @@ pub unsafe extern "C" fn dynamo_llm_init( ...@@ -96,7 +96,7 @@ pub unsafe extern "C" fn dynamo_llm_init(
match result { match result {
Ok(_) => match KV_PUB.get_or_try_init(move || { Ok(_) => match KV_PUB.get_or_try_init(move || {
dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size as usize) dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size)
}) { }) {
Ok(_) => DynamoLlmResult::OK, Ok(_) => DynamoLlmResult::OK,
Err(e) => { Err(e) => {
...@@ -139,7 +139,7 @@ fn dynamo_create_kv_publisher( ...@@ -139,7 +139,7 @@ fn dynamo_create_kv_publisher(
namespace: String, namespace: String,
component: String, component: String,
worker_id: i64, worker_id: i64,
kv_block_size: usize, kv_block_size: u32,
) -> Result<KvEventPublisher, anyhow::Error> { ) -> Result<KvEventPublisher, anyhow::Error> {
tracing::info!("Creating KV Publisher for model: {}", component); tracing::info!("Creating KV Publisher for model: {}", component);
match DRT match DRT
...@@ -158,7 +158,7 @@ fn kv_event_create_stored_block_from_parts( ...@@ -158,7 +158,7 @@ fn kv_event_create_stored_block_from_parts(
block_hash: u64, block_hash: u64,
token_ids: *const u32, token_ids: *const u32,
num_tokens: usize, num_tokens: usize,
kv_block_size: usize, kv_block_size: u32,
_lora_id: u64, _lora_id: u64,
) -> KvCacheStoredBlockData { ) -> KvCacheStoredBlockData {
let tokens_hash = compute_block_hash_for_seq( let tokens_hash = compute_block_hash_for_seq(
...@@ -174,7 +174,7 @@ static WARN_COUNT: AtomicU32 = AtomicU32::new(0); ...@@ -174,7 +174,7 @@ static WARN_COUNT: AtomicU32 = AtomicU32::new(0);
fn kv_event_create_stored_from_parts( fn kv_event_create_stored_from_parts(
kv_params: DynamoKvStoredEventParams, kv_params: DynamoKvStoredEventParams,
kv_block_size: usize, kv_block_size: u32,
) -> KvCacheEvent { ) -> KvCacheEvent {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new(); let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
...@@ -188,7 +188,7 @@ fn kv_event_create_stored_from_parts( ...@@ -188,7 +188,7 @@ fn kv_event_create_stored_from_parts(
.offset(block_idx.try_into().unwrap()) .offset(block_idx.try_into().unwrap())
}; };
if num_toks != kv_block_size { if num_toks != (kv_block_size as usize) {
if WARN_COUNT if WARN_COUNT
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| { .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |c| {
if c < 3 { if c < 3 {
......
...@@ -1011,6 +1011,18 @@ dependencies = [ ...@@ -1011,6 +1011,18 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "dialoguer"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de"
dependencies = [
"console",
"shell-words",
"tempfile",
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.7" version = "0.10.7"
...@@ -1123,6 +1135,7 @@ dependencies = [ ...@@ -1123,6 +1135,7 @@ dependencies = [
"cudarc", "cudarc",
"derive-getters", "derive-getters",
"derive_builder", "derive_builder",
"dialoguer",
"dynamo-runtime", "dynamo-runtime",
"either", "either",
"erased-serde", "erased-serde",
...@@ -1131,6 +1144,7 @@ dependencies = [ ...@@ -1131,6 +1144,7 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"ggus", "ggus",
"hf-hub", "hf-hub",
"humantime",
"itertools 0.14.0", "itertools 0.14.0",
"memmap2", "memmap2",
"minijinja", "minijinja",
...@@ -4224,6 +4238,12 @@ dependencies = [ ...@@ -4224,6 +4238,12 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "shell-words"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.3.0" version = "1.3.0"
......
...@@ -9,6 +9,7 @@ use pyo3::types::{PyDict, PyList, PyString}; ...@@ -9,6 +9,7 @@ use pyo3::types::{PyDict, PyList, PyString};
use pyo3::IntoPyObjectExt; use pyo3::IntoPyObjectExt;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::{exceptions::PyException, prelude::*};
use rs::pipeline::network::Ingress; use rs::pipeline::network::Ingress;
use std::path::PathBuf;
use std::{fmt::Display, sync::Arc}; use std::{fmt::Display, sync::Arc};
use tokio::sync::Mutex; use tokio::sync::Mutex;
...@@ -104,8 +105,8 @@ fn register_llm<'p>( ...@@ -104,8 +105,8 @@ fn register_llm<'p>(
endpoint: Endpoint, endpoint: Endpoint,
model_path: &str, model_path: &str,
model_name: Option<&str>, model_name: Option<&str>,
context_length: Option<usize>, context_length: Option<u32>,
kv_cache_block_size: Option<usize>, kv_cache_block_size: Option<u32>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type { let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat, ModelType::Chat => llm_rs::model_type::ModelType::Chat,
...@@ -117,18 +118,14 @@ fn register_llm<'p>( ...@@ -117,18 +118,14 @@ fn register_llm<'p>(
let inner_path = model_path.to_string(); let inner_path = model_path.to_string();
let model_name = model_name.map(|n| n.to_string()); let model_name = model_name.map(|n| n.to_string());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
builder
.model_path(Some(PathBuf::from(inner_path)))
.model_name(model_name)
.context_length(context_length)
.kv_cache_block_size(kv_cache_block_size);
// Download from HF, load the ModelDeploymentCard // Download from HF, load the ModelDeploymentCard
let mut local_model = let mut local_model = builder.build().await.map_err(to_pyerr)?;
llm_rs::local_model::LocalModel::prepare(&inner_path, None, model_name)
.await
.map_err(to_pyerr)?;
if let Some(context_length) = context_length {
local_model.set_context_length(context_length);
}
if let Some(kv_cache_block_size) = kv_cache_block_size {
local_model.set_kv_cache_block_size(kv_cache_block_size);
}
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
local_model local_model
.attach(&endpoint.inner, model_type_obj) .attach(&endpoint.inner, model_type_obj)
......
...@@ -40,8 +40,11 @@ impl KvRouter { ...@@ -40,8 +40,11 @@ impl KvRouter {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async { runtime.block_on(async {
let inner = let inner = llm_rs::kv_router::KvRouter::new(
llm_rs::kv_router::KvRouter::new(component.inner.clone(), kv_block_size, None) component.inner.clone(),
kv_block_size as u32,
None,
)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(Self { Ok(Self {
...@@ -73,7 +76,7 @@ pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> ...@@ -73,7 +76,7 @@ pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) ->
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
} }
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size); let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32);
Ok(hashes.into_iter().map(|h| h.0).collect()) Ok(hashes.into_iter().map(|h| h.0).collect())
} }
...@@ -191,7 +194,7 @@ impl ZmqKvEventPublisher { ...@@ -191,7 +194,7 @@ impl ZmqKvEventPublisher {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner, component.inner,
config.worker_id, config.worker_id,
config.kv_block_size, config.kv_block_size as u32,
Some(KvEventSourceConfig::Zmq { Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint, endpoint: config.zmq_endpoint,
topic: config.zmq_topic, topic: config.zmq_topic,
...@@ -232,7 +235,7 @@ impl ZmqKvEventListener { ...@@ -232,7 +235,7 @@ impl ZmqKvEventListener {
zmq_topic, zmq_topic,
tx, tx,
shutdown_token.clone(), shutdown_token.clone(),
kv_block_size, kv_block_size as u32,
)); ));
Ok(Self { Ok(Self {
...@@ -293,7 +296,7 @@ impl KvEventPublisher { ...@@ -293,7 +296,7 @@ impl KvEventPublisher {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner, component.inner,
worker_id, worker_id,
kv_block_size, kv_block_size as u32,
None, None,
) )
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -322,7 +325,7 @@ impl KvEventPublisher { ...@@ -322,7 +325,7 @@ impl KvEventPublisher {
data: KvCacheEventData::Stored(KvCacheStoreData { data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash::from), parent_hash: parent_hash.map(ExternalSequenceBlockHash::from),
blocks: create_stored_blocks( blocks: create_stored_blocks(
self.kv_block_size, self.kv_block_size as u32,
&token_ids, &token_ids,
&num_block_tokens, &num_block_tokens,
&block_hashes, &block_hashes,
...@@ -446,7 +449,7 @@ impl KvIndexer { ...@@ -446,7 +449,7 @@ impl KvIndexer {
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> = let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new( llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(), component.inner.drt().runtime().child_token(),
kv_block_size, kv_block_size as u32,
) )
.into(); .into();
// [gluo TODO] try subscribe_with_type::<RouterEvent>, // [gluo TODO] try subscribe_with_type::<RouterEvent>,
...@@ -478,7 +481,7 @@ impl KvIndexer { ...@@ -478,7 +481,7 @@ impl KvIndexer {
} }
fn block_size(&self) -> usize { fn block_size(&self) -> usize {
self.inner.block_size() self.inner.block_size() as usize
} }
fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> { fn find_matches<'p>(&self, py: Python<'p>, sequence: Vec<u64>) -> PyResult<Bound<'p, PyAny>> {
......
...@@ -78,7 +78,7 @@ impl LlamacppEngine { ...@@ -78,7 +78,7 @@ impl LlamacppEngine {
let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS); let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS);
let llama_ctx_params = if model_config.card().context_length > 0 { let llama_ctx_params = if model_config.card().context_length > 0 {
let n_ctx = NonZeroU32::new(model_config.card().context_length as u32); let n_ctx = NonZeroU32::new(model_config.card().context_length);
LlamaContextParams::default().with_n_ctx(n_ctx) LlamaContextParams::default().with_n_ctx(n_ctx)
} else { } else {
// Context length defaults to 512 currently // Context length defaults to 512 currently
......
...@@ -128,7 +128,7 @@ impl MistralRsEngine { ...@@ -128,7 +128,7 @@ impl MistralRsEngine {
.build(None)? .build(None)?
}; };
let mut max_seq_len = model.card().context_length; let mut max_seq_len = model.card().context_length as usize;
if max_seq_len == 0 { if max_seq_len == 0 {
tracing::info!("context_length is 0. Probably error reading from model."); tracing::info!("context_length is 0. Probably error reading from model.");
max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN; max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
......
...@@ -53,6 +53,7 @@ either = { workspace = true } ...@@ -53,6 +53,7 @@ either = { workspace = true }
etcd-client = { workspace = true } etcd-client = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
hf-hub = { workspace = true } hf-hub = { workspace = true }
humantime = { workspace = true } # input/batch
rand = { workspace = true } rand = { workspace = true }
oneshot = { workspace = true } oneshot = { workspace = true }
prometheus = { workspace = true } prometheus = { workspace = true }
...@@ -80,6 +81,9 @@ offset-allocator = "0.2" ...@@ -80,6 +81,9 @@ offset-allocator = "0.2"
regex = "1" regex = "1"
rayon = "1" rayon = "1"
# input/text
dialoguer = { version = "0.11", default-features = false, features = ["editor", "history"] }
# block_manager # block_manager
nixl-sys = {git="https://github.com/ai-dynamo/nixl", rev = "a7c654d46a14cd5ce635cc8c02433d71df93dedf", optional = true } nixl-sys = {git="https://github.com/ai-dynamo/nixl", rev = "a7c654d46a14cd5ce635cc8c02433d71df93dedf", optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true } cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
......
...@@ -1605,7 +1605,7 @@ mod tests { ...@@ -1605,7 +1605,7 @@ mod tests {
use dynamo_runtime::logging::init as init_logging; use dynamo_runtime::logging::init as init_logging;
use nixl_sys::Agent as NixlAgent; use nixl_sys::Agent as NixlAgent;
const BLOCK_SIZE: usize = 4; const BLOCK_SIZE: u32 = 4;
const SALT_HASH: SaltHash = 12345; const SALT_HASH: SaltHash = 12345;
// Helper to create a default reset block // Helper to create a default reset block
...@@ -1666,7 +1666,7 @@ mod tests { ...@@ -1666,7 +1666,7 @@ mod tests {
// Extend to fill capacity // Extend to fill capacity
assert!(block.add_tokens(Tokens::from(vec![4])).is_ok()); // 1, 2, 3, 4 assert!(block.add_tokens(Tokens::from(vec![4])).is_ok()); // 1, 2, 3, 4
assert_eq!(block.len(), BLOCK_SIZE); assert_eq!(block.len(), BLOCK_SIZE as usize);
// Append when full (should fail) // Append when full (should fail)
assert!(block.add_token(5).is_err(), "Append on full Partial block"); assert!(block.add_token(5).is_err(), "Append on full Partial block");
...@@ -1690,7 +1690,7 @@ mod tests { ...@@ -1690,7 +1690,7 @@ mod tests {
// Fill block again for commit // Fill block again for commit
assert!(block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).is_ok()); assert!(block.add_tokens(Tokens::from(vec![1, 2, 3, 4])).is_ok());
assert_eq!(block.len(), BLOCK_SIZE); assert_eq!(block.len(), BLOCK_SIZE as usize);
// --- Partial -> Complete (via commit) --- // // --- Partial -> Complete (via commit) --- //
assert!(block.commit().is_ok()); assert!(block.commit().is_ok());
......
...@@ -43,7 +43,7 @@ impl BlockState { ...@@ -43,7 +43,7 @@ impl BlockState {
return Err(BlockStateInvalid("Block is not reset".to_string())); return Err(BlockStateInvalid("Block is not reset".to_string()));
} }
let block = PartialTokenBlock::create_sequence_root(page_size, salt_hash); let block = PartialTokenBlock::create_sequence_root(page_size as u32, salt_hash);
*self = BlockState::Partial(PartialState::new(block)); *self = BlockState::Partial(PartialState::new(block));
Ok(()) Ok(())
} }
......
...@@ -648,7 +648,7 @@ pub(crate) mod tests { ...@@ -648,7 +648,7 @@ pub(crate) mod tests {
/// Each block is initialized to the Complete state and then Registered. /// Each block is initialized to the Complete state and then Registered.
pub fn create_blocks( pub fn create_blocks(
tokens: Tokens, tokens: Tokens,
block_size: usize, block_size: u32,
async_runtime: Handle, async_runtime: Handle,
) -> Vec<Block<NullDeviceStorage, TestMetadata>> { ) -> Vec<Block<NullDeviceStorage, TestMetadata>> {
let (token_blocks, _partial_token_block) = let (token_blocks, _partial_token_block) =
...@@ -691,7 +691,7 @@ pub(crate) mod tests { ...@@ -691,7 +691,7 @@ pub(crate) mod tests {
pub fn acquire_blocks( pub fn acquire_blocks(
tokens: Tokens, tokens: Tokens,
block_size: usize, block_size: u32,
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>, pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>,
async_runtime: Handle, async_runtime: Handle,
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) { ) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) {
...@@ -749,7 +749,7 @@ pub(crate) mod tests { ...@@ -749,7 +749,7 @@ pub(crate) mod tests {
let async_runtime = tokio::runtime::Runtime::new().unwrap(); let async_runtime = tokio::runtime::Runtime::new().unwrap();
const PAGE_SIZE: usize = 2; const PAGE_SIZE: u32 = 2;
let mut pool = create_block_pool(10); let mut pool = create_block_pool(10);
assert_eq!(pool.total_blocks(), 10); assert_eq!(pool.total_blocks(), 10);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment