Unverified Commit 92f06b0e authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(dynamo-run): Refactor to library (#1687)

Move much of what was in the `dynamo-run` crate into `dynamo-llm` so that everyone can use it.

Example usage:

1. Create a `LocalModel`:

```
    let local_model = LocalModelBuilder::default()
	.model_path(Some("Qwen/Qwen3-0.6B".into()))
	.http_port(8080)
	.build().await?;
```

2. Make an engine:

```
    let engine_config = EngineConfig::StaticFull {
	engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
	model: Box::new(local_model),
    };
```

3. Connect it to an input and run it

```
    dynamo_llm::entrypoint::input::run_input(Input::Http, runtime, engine_config).await?;
```

For https://github.com/ai-dynamo/dynamo/issues/1647

Code Rabbit summary, thanks:
  * Introduced a flexible builder pattern for local model configuration, allowing advanced customization and easier initialization.
  * Added new input modes and unified input handling, supporting interactive chat, HTTP server, batch file, and distributed endpoint modes.
  * Centralized engine configuration and routing, enabling more extensible and maintainable engine management.
  * Simplified and modularized the codebase by moving input and engine logic into dedicated modules.
  * Replaced direct model construction with an asynchronous builder for improved clarity and extensibility.
  * Streamlined configuration and validation for flags and router settings.
  * Added validation to prevent incompatible input and output combinations in endpoint and dynamic modes.
parent 3b62692f
......@@ -180,7 +180,7 @@ impl ModelManager {
&self,
model_name: &str,
component: &Component,
kv_cache_block_size: usize,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
......@@ -209,7 +209,7 @@ impl ModelManager {
&self,
model_name: &str,
component: &Component,
kv_cache_block_size: usize,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! The entrypoint module provides tools to build a Dynamo runner.
//! - Create an EngineConfig of the engine (potentially auto-discovered) to execute
//! - Connect it to an Input
pub mod input;
use std::sync::Arc;
use dynamo_runtime::pipeline::RouterMode;
use crate::{
backend::ExecutionContext, engines::StreamingEngine, kv_router::KvRouterConfig,
local_model::LocalModel,
};
/// Routing settings handed to an input: which router mode to use, together with
/// the configuration for the KV router.
#[derive(Debug, Clone, Default)]
pub struct RouterConfig {
    /// How requests are routed to workers.
    pub router_mode: RouterMode,
    /// Settings for the KV router.
    pub kv_router_config: KvRouterConfig,
}

impl RouterConfig {
    /// Bundle a router mode and a KV router configuration into one value.
    pub fn new(router_mode: RouterMode, kv_router_config: KvRouterConfig) -> Self {
        RouterConfig {
            router_mode,
            kv_router_config,
        }
    }
}
/// The engine to run, together with the `LocalModel` that describes it.
pub enum EngineConfig {
    /// Remote networked engines
    Dynamic(Box<LocalModel>),

    /// A Full service engine does its own tokenization and prompt formatting.
    StaticFull {
        engine: Arc<dyn StreamingEngine>,
        model: Box<LocalModel>,
    },

    /// A core engine expects to be wrapped with pre/post processors that handle tokenization.
    StaticCore {
        engine: ExecutionContext,
        model: Box<LocalModel>,
    },
}

impl EngineConfig {
    /// Borrow the `LocalModel` carried by this configuration, whichever variant it is.
    fn local_model(&self) -> &LocalModel {
        match self {
            Self::Dynamic(model) => model,
            // Both static variants store the model in a field of the same name and type.
            Self::StaticFull { model, .. } | Self::StaticCore { model, .. } => model,
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! This module contains tools to gather a prompt from a user, forward it to an engine and return
//! the response.
//! See the Input enum for the inputs available. Input::Http (OpenAI compatible HTTP server)
//! and Input::Text (interactive chat) are good places to start.
//! The main entry point is `run_input`.
use std::{
fmt,
io::{IsTerminal as _, Read as _},
path::PathBuf,
};
pub mod batch;
mod common;
pub mod endpoint;
pub mod http;
pub mod text;
use dynamo_runtime::{protocols::ENDPOINT_SCHEME, DistributedRuntime};
const BATCH_PREFIX: &str = "batch:";
/// The various ways of connecting prompts to an engine
///
/// Parsed from a string with `TryFrom<&str>` and rendered back with `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Input {
    /// Run an OpenAI compatible HTTP server
    Http,

    /// Single prompt on stdin
    Stdin,

    /// Interactive chat
    Text,

    /// Pull requests from a namespace/component/endpoint path.
    Endpoint(String),

    /// Batch mode. Run all the prompts, write the outputs, exit.
    Batch(PathBuf),
}
impl TryFrom<&str> for Input {
    type Error = anyhow::Error;

    /// Parse an `in=` option.
    ///
    /// The fixed keywords `http`, `text` and `stdin` are tried first, then anything starting
    /// with `ENDPOINT_SCHEME` (kept whole, scheme included), then `batch:<path>` (prefix
    /// removed). Anything else is an error.
    fn try_from(s: &str) -> anyhow::Result<Self> {
        let parsed = match s {
            "http" => Input::Http,
            "text" => Input::Text,
            "stdin" => Input::Stdin,
            _ if s.starts_with(ENDPOINT_SCHEME) => Input::Endpoint(s.to_string()),
            _ => match s.strip_prefix(BATCH_PREFIX) {
                Some(batch_file) => Input::Batch(PathBuf::from(batch_file)),
                None => return Err(anyhow::anyhow!("Invalid in= option '{e}'", e = s)),
            },
        };
        Ok(parsed)
    }
}
impl fmt::Display for Input {
    /// Writes the keyword for the simple modes, the stored path string for `Endpoint`,
    /// and the file path (without the `batch:` prefix) for `Batch`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Input::Http => f.write_str("http"),
            Input::Text => f.write_str("text"),
            Input::Stdin => f.write_str("stdin"),
            Input::Endpoint(endpoint) => f.write_str(endpoint),
            Input::Batch(file) => write!(f, "{}", file.display()),
        }
    }
}
impl Default for Input {
    /// Interactive chat when stdin is a terminal; otherwise a single prompt read from stdin.
    fn default() -> Self {
        let interactive = std::io::stdin().is_terminal();
        if !interactive {
            return Input::Stdin;
        }
        Input::Text
    }
}
/// Run the given engine (EngineConfig) connected to an input.
/// Does not return until the input exits.
pub async fn run_input(
in_opt: Input,
runtime: dynamo_runtime::Runtime,
engine_config: super::EngineConfig,
) -> anyhow::Result<()> {
match in_opt {
Input::Http => {
http::run(runtime.clone(), engine_config).await?;
}
Input::Text => {
text::run(runtime.clone(), None, engine_config).await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
text::run(runtime.clone(), Some(prompt), engine_config).await?;
}
Input::Batch(path) => {
batch::run(runtime.clone(), path, engine_config).await?;
}
Input::Endpoint(path) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
endpoint::run(distributed_runtime, path, engine_config).await?;
}
}
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Context as _;
use async_openai::types::FinishReason;
use dynamo_llm::model_card::model::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::request_template::RequestTemplate;
use dynamo_llm::types::openai::chat_completions::{
use crate::preprocessor::OpenAIPreprocessor;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use anyhow::Context as _;
use async_openai::types::FinishReason;
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
......@@ -19,8 +18,8 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use crate::input::common;
use crate::{EngineConfig, Flags};
use crate::entrypoint::input::common;
use crate::entrypoint::EngineConfig;
/// Max tokens in each response.
/// TODO: For batch mode this should be the full context size of the model
......@@ -53,11 +52,8 @@ struct Entry {
pub async fn run(
runtime: Runtime,
_flags: Flags,
card: ModelDeploymentCard,
input_jsonl: PathBuf,
engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
// Check if the path exists and is a directory
......@@ -68,11 +64,10 @@ pub async fn run(
);
}
let prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let service_name_ref = Arc::new(prepared_engine.service_name);
let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let pre_processor = if card.has_tokenizer() {
Some(OpenAIPreprocessor::new(card).await?)
let pre_processor = if prepared_engine.has_tokenizer() {
Some(OpenAIPreprocessor::new(prepared_engine.card.take().unwrap()).await?)
} else {
None
};
......@@ -85,6 +80,7 @@ pub async fn run(
tracing::error!(%err, "Failed writing output to {}", output_file.display());
}
});
let service_name_ref = Arc::new(prepared_engine.service_name);
let tokens_in = Arc::new(AtomicU64::new(0));
let tokens_out = Arc::new(AtomicU64::new(0));
......@@ -98,7 +94,7 @@ pub async fn run(
tracing::info!("Timer start.");
let start = Instant::now();
let mut lines = buffered_input.lines();
let template: Option<Arc<RequestTemplate>> = template.map(Arc::new);
let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
while let Ok(Some(line)) = lines.next_line().await {
if cancel_token.is_cancelled() {
break;
......
......@@ -3,13 +3,15 @@
use std::pin::Pin;
use dynamo_llm::{
use crate::{
backend::{Backend, ExecutionContext},
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
engines::StreamingEngineAdapter,
entrypoint::EngineConfig,
model_card::ModelDeploymentCard,
preprocessor::OpenAIPreprocessor,
protocols::common::llm_backend::{BackendOutput, PreprocessedRequest},
request_template::RequestTemplate,
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
......@@ -25,12 +27,22 @@ use dynamo_runtime::{
};
use std::sync::Arc;
use crate::EngineConfig;
/// The result of `prepare_engine`: an engine exposed through the OpenAI chat-completions
/// interface, plus the metadata the input modes read alongside it.
pub struct PreparedEngine {
    pub service_name: String,
    pub engine: OpenAIChatCompletionsStreamingEngine,
    // NOTE(review): set to true only for the StaticCore pipeline; exact meaning is defined
    // by the consumers of this flag — confirm there.
    pub inspect_template: bool,
    /// The model deployment card, when one is available (`None` for dynamic engines).
    pub card: Option<ModelDeploymentCard>,
    pub request_template: Option<RequestTemplate>,
}

impl PreparedEngine {
    /// True only when a card is present and that card reports having a tokenizer.
    pub fn has_tokenizer(&self) -> bool {
        self.card
            .as_ref()
            .is_some_and(|card| card.has_tokenizer())
    }
}
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
......@@ -39,7 +51,7 @@ pub async fn prepare_engine(
engine_config: EngineConfig,
) -> anyhow::Result<PreparedEngine> {
match engine_config {
EngineConfig::Dynamic => {
EngineConfig::Dynamic(local_model) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let Some(etcd_client) = distributed_runtime.etcd_client() else {
......@@ -71,6 +83,8 @@ pub async fn prepare_engine(
service_name: model_service_name,
engine,
inspect_template: false,
card: None,
request_template: local_model.request_template(),
})
}
EngineConfig::StaticFull { engine, model } => {
......@@ -81,6 +95,8 @@ pub async fn prepare_engine(
service_name,
engine,
inspect_template: false,
request_template: model.request_template(),
card: Some(model.into_card()),
})
}
EngineConfig::StaticCore {
......@@ -99,6 +115,8 @@ pub async fn prepare_engine(
service_name,
engine: pipeline,
inspect_template: true,
request_template: model.request_template(),
card: Some(model.into_card()),
})
}
}
......@@ -137,21 +155,21 @@ where
#[cfg(test)]
mod tests {
use super::*;
use dynamo_llm::types::openai::{
use crate::types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
};
const HF_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../lib/llm/tests/data/sample-models/mock-llama-3.1-8b-instruct"
"/tests/data/sample-models/mock-llama-3.1-8b-instruct"
);
#[tokio::test]
async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?;
let engine = dynamo_llm::engines::make_engine_core();
let engine = crate::engines::make_engine_core();
// Build pipeline for chat completions
let pipeline = build_pipeline::<
......@@ -170,7 +188,7 @@ mod tests {
async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?;
let engine = dynamo_llm::engines::make_engine_core();
let engine = crate::engines::make_engine_core();
// Build pipeline for completions
let pipeline =
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{future::Future, pin::Pin, sync::Arc};
use dynamo_llm::{
use crate::{
backend::Backend,
engines::StreamingEngineAdapter,
model_type::ModelType,
......@@ -33,7 +21,7 @@ use dynamo_runtime::pipeline::{
};
use dynamo_runtime::{protocols::Endpoint as EndpointId, DistributedRuntime};
use crate::EngineConfig;
use crate::entrypoint::EngineConfig;
pub async fn run(
distributed_runtime: DistributedRuntime,
......@@ -91,7 +79,7 @@ pub async fn run(
(Box::pin(fut), Some(model.card().clone()))
}
EngineConfig::Dynamic => {
EngineConfig::Dynamic(_) => {
// We can only get here for `in=dyn out=vllm|sglang`, because vllm and sglang are a
// subprocess that we talk to like a remote endpoint.
// That means the vllm/sglang subprocess is doing all the work, we are idle.
......
......@@ -3,14 +3,12 @@
use std::sync::Arc;
use crate::input::common;
use crate::{EngineConfig, Flags};
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::{
use crate::{
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
engines::StreamingEngineAdapter,
entrypoint::{input::common, EngineConfig},
http::service::service_v2,
request_template::RequestTemplate,
kv_router::KvRouterConfig,
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
......@@ -23,32 +21,28 @@ use dynamo_runtime::transports::etcd;
use dynamo_runtime::{DistributedRuntime, Runtime};
/// Build and run an HTTP service
pub async fn run(
runtime: Runtime,
flags: Flags,
engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder()
.port(flags.http_port)
.port(engine_config.local_model().http_port())
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true)
.with_request_template(template)
.with_request_template(engine_config.local_model().request_template())
.build()?;
match engine_config {
EngineConfig::Dynamic => {
EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
match distributed_runtime.etcd_client() {
Some(etcd_client) => {
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to HTTP service
run_watcher(
distributed_runtime,
http_service.state().manager_clone(),
etcd_client.clone(),
MODEL_ROOT_PATH,
flags.router_mode.into(),
Some(flags.kv_router_config()),
router_config.router_mode,
Some(router_config.kv_router_config.clone()),
)
.await?;
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::protocols::openai::nvext::NvExt;
use dynamo_llm::types::openai::chat_completions::{
use crate::protocols::openai::nvext::NvExt;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt;
use std::io::{ErrorKind, Write};
use crate::input::common;
use crate::{EngineConfig, Flags, RequestTemplate};
use crate::entrypoint::input::common;
use crate::entrypoint::EngineConfig;
/// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this
......@@ -18,20 +19,19 @@ const MAX_TOKENS: u32 = 8192;
pub async fn run(
runtime: Runtime,
_flags: Flags,
single_prompt: Option<String>,
engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
let prepared_engine = common::prepare_engine(runtime, engine_config).await?;
// TODO: Pass prepared_engine directly
main_loop(
cancel_token,
&prepared_engine.service_name,
prepared_engine.engine,
single_prompt,
prepared_engine.inspect_template,
template,
prepared_engine.request_template,
)
.await
}
......
......@@ -50,7 +50,7 @@ pub trait WorkerSelector {
&self,
workers: &ProcessedEndpoints,
request: &SchedulingRequest,
block_size: usize,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
......@@ -104,13 +104,13 @@ impl KvRouterConfig {
pub struct KvRouter {
indexer: KvIndexer,
scheduler: KvScheduler,
block_size: usize,
block_size: u32,
}
impl KvRouter {
pub async fn new(
component: Component,
block_size: usize,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Self> {
let cancellation_token = component
......@@ -196,7 +196,7 @@ impl KvRouter {
}
/// Get the block size this router was configured with
pub fn block_size(&self) -> usize {
pub fn block_size(&self) -> u32 {
self.block_size
}
}
......
......@@ -119,9 +119,9 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
/// ### Returns
///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> {
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> {
tokens
.chunks_exact(kv_block_size) // Split into chunks of kv_block_size elements
.chunks_exact(kv_block_size as usize) // Split into chunks of kv_block_size elements
.map(|chunk| {
let bytes: Vec<u8> = chunk
.iter()
......@@ -527,7 +527,7 @@ pub struct KvIndexer {
/// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle.
kv_block_size: usize,
kv_block_size: u32,
}
impl KvIndexer {
......@@ -544,7 +544,7 @@ impl KvIndexer {
pub fn new_with_frequency(
token: CancellationToken,
expiration_duration: Option<Duration>,
kv_block_size: usize,
kv_block_size: u32,
) -> Self {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
......@@ -611,11 +611,11 @@ impl KvIndexer {
}
}
pub fn block_size(&self) -> usize {
pub fn block_size(&self) -> u32 {
self.kv_block_size
}
pub fn new(token: CancellationToken, kv_block_size: usize) -> Self {
pub fn new(token: CancellationToken, kv_block_size: u32) -> Self {
Self::new_with_frequency(token, None, kv_block_size)
}
......@@ -697,7 +697,7 @@ pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown.
cancel: CancellationToken,
/// The size of the KV block this indexer can handle.
kv_block_size: usize,
kv_block_size: u32,
worker_assignments: HashMap<WorkerId, usize>,
worker_counts: Vec<usize>,
......@@ -723,7 +723,7 @@ impl KvIndexerSharded {
token: CancellationToken,
num_shards: usize,
expiration_duration: Option<Duration>,
kv_block_size: usize,
kv_block_size: u32,
) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards];
......@@ -802,11 +802,11 @@ impl KvIndexerSharded {
}
}
pub fn block_size(&self) -> usize {
pub fn block_size(&self) -> u32 {
self.kv_block_size
}
pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self {
pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: u32) -> Self {
Self::new_with_frequency(token, num_shards, None, kv_block_size)
}
}
......@@ -1312,24 +1312,20 @@ mod tests {
#[case(11)]
#[case(32)]
#[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) {
fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
setup();
// create a sequence of exactly kv_block_size elements: one full block
let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>();
let sequence = (0..kv_block_size).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1);
// create a sequence of kv_block_size + 1 elements: still only one full block
let sequence = (0..(kv_block_size + 1))
.map(|i| i as u32)
.collect::<Vec<u32>>();
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1);
// create a sequence of 2 * kv_block_size + 1 elements: two full blocks
let sequence = (0..(2 * kv_block_size + 1))
.map(|i| i as u32)
.collect::<Vec<u32>>();
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 2);
}
......@@ -1337,7 +1333,7 @@ mod tests {
fn make_indexer(
token: &CancellationToken,
num_shards: usize,
kv_block_size: usize,
kv_block_size: u32,
) -> Box<dyn KvIndexerInterface> {
if num_shards == 1 {
Box::new(KvIndexer::new(token.clone(), kv_block_size))
......@@ -1360,7 +1356,7 @@ mod tests {
#[tokio::test]
#[apply(indexer_template)]
async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) {
async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
setup();
let token: CancellationToken = CancellationToken::new();
let _ = make_indexer(&token, num_shards, kv_block_size);
......@@ -1368,7 +1364,7 @@ mod tests {
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches(num_shards: usize, kv_block_size: usize) {
async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
......@@ -1381,7 +1377,7 @@ mod tests {
#[tokio::test]
#[apply(indexer_template)]
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) {
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
......@@ -1394,7 +1390,7 @@ mod tests {
#[tokio::test]
#[apply(indexer_template)]
async fn test_apply_event(num_shards: usize, kv_block_size: usize) {
async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
setup();
let worker_id = 0;
......@@ -1409,7 +1405,7 @@ mod tests {
#[tokio::test]
#[apply(indexer_template)]
async fn test_shutdown(num_shards: usize, kv_block_size: usize) {
async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
setup();
let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
......@@ -1419,7 +1415,7 @@ mod tests {
#[tokio::test]
#[apply(indexer_template)]
async fn test_frequency(num_shards: usize, kv_block_size: usize) {
async fn test_frequency(num_shards: usize, kv_block_size: u32) {
const ONE_MILLIS: Duration = Duration::from_millis(1);
setup();
......
......@@ -62,7 +62,7 @@ impl KvEventSource {
/// Start the event source from a [`KvEventSourceConfig`].
fn start(
component: Component,
kv_block_size: usize,
kv_block_size: u32,
source_config: KvEventSourceConfig,
cancellation_token: CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>,
......@@ -98,7 +98,7 @@ impl KvEventSource {
/// A publisher of KV events.
pub struct KvEventPublisher {
/// The size of the KV block.
kv_block_size: usize,
kv_block_size: u32,
/// The source of KV events.
/// Can be `None` if all events provided through [`KvEventPublisher::publish`].
source: Option<KvEventSource>,
......@@ -112,7 +112,7 @@ impl KvEventPublisher {
pub fn new(
component: Component,
worker_id: i64,
kv_block_size: usize,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
let cancellation_token = CancellationToken::new();
......@@ -155,7 +155,7 @@ impl KvEventPublisher {
self.tx.send(event)
}
pub fn kv_block_size(&self) -> usize {
pub fn kv_block_size(&self) -> u32 {
self.kv_block_size
}
......@@ -223,7 +223,7 @@ pub async fn start_zmq_listener(
zmq_topic: String,
tx: mpsc::UnboundedSender<KvCacheEvent>,
cancellation_token: CancellationToken,
kv_block_size: usize,
kv_block_size: u32,
) {
tracing::debug!(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
......@@ -335,7 +335,7 @@ pub async fn start_zmq_listener(
fn convert_event(
raw: RawKvEvent,
event_id: u64,
kv_block_size: usize,
kv_block_size: u32,
warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent {
match raw {
......@@ -382,7 +382,7 @@ fn convert_event(
}
pub fn create_stored_block_from_parts(
kv_block_size: usize,
kv_block_size: u32,
block_hash: i64,
token_ids: &[u32],
_lora_id: u64,
......@@ -395,7 +395,7 @@ pub fn create_stored_block_from_parts(
}
pub fn create_stored_blocks(
kv_block_size: usize,
kv_block_size: u32,
token_ids: &[u32],
num_block_tokens: &[u64],
block_hashes: &[i64],
......
......@@ -92,7 +92,7 @@ pub struct KvScheduler {
impl KvScheduler {
pub async fn start(
ns: Namespace,
block_size: usize,
block_size: u32,
endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Self, KvSchedulerError> {
......@@ -299,7 +299,7 @@ impl WorkerSelector for DefaultWorkerSelector {
&self,
workers: &ProcessedEndpoints,
request: &SchedulingRequest,
block_size: usize,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0);
......@@ -307,7 +307,7 @@ impl WorkerSelector for DefaultWorkerSelector {
return Err(KvSchedulerError::NoEndpoints);
}
let request_blocks = request.isl_tokens.div_ceil(block_size);
let request_blocks = request.isl_tokens.div_ceil(block_size as usize);
let mut worker_logits = HashMap::new();
// Calculate logits for each worker
......
......@@ -15,6 +15,7 @@ pub mod common;
pub mod disagg_router;
pub mod discovery;
pub mod engines;
pub mod entrypoint;
pub mod gguf;
pub mod http;
pub mod hub;
......
......@@ -5,6 +5,9 @@ use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::Context as _;
use dynamo_runtime::protocols::Endpoint as EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{
component::{Component, Endpoint},
......@@ -12,8 +15,10 @@ use dynamo_runtime::{
};
use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::ModelType;
use crate::request_template::RequestTemplate;
mod network_name;
pub use network_name::ModelNetworkName;
......@@ -25,58 +30,85 @@ const HF_SCHEME: &str = "hf://";
/// is invisible, for example in a text chat.
const DEFAULT_NAME: &str = "dynamo";
#[derive(Debug, Clone)]
pub struct LocalModel {
full_path: PathBuf,
card: ModelDeploymentCard,
/// Engines don't usually provide a default, so we do.
const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16;
/// We can't have it default to 0, so pick something
const DEFAULT_HTTP_PORT: u16 = 8080;
pub struct LocalModelBuilder {
model_path: Option<PathBuf>,
model_name: Option<String>,
model_config: Option<PathBuf>,
endpoint_id: Option<EndpointId>,
context_length: Option<u32>,
template_file: Option<PathBuf>,
router_config: Option<RouterConfig>,
kv_cache_block_size: u32,
http_port: u16,
}
impl Default for LocalModel {
impl Default for LocalModelBuilder {
fn default() -> Self {
LocalModel {
full_path: PathBuf::new(),
card: ModelDeploymentCard::with_name_only(DEFAULT_NAME),
LocalModelBuilder {
kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE,
http_port: DEFAULT_HTTP_PORT,
model_path: Default::default(),
model_name: Default::default(),
model_config: Default::default(),
endpoint_id: Default::default(),
context_length: Default::default(),
template_file: Default::default(),
router_config: Default::default(),
}
}
}
impl LocalModel {
pub fn with_name_only(name: &str) -> Self {
LocalModel {
card: ModelDeploymentCard::with_name_only(name),
..Default::default()
impl LocalModelBuilder {
pub fn model_path(&mut self, model_path: Option<PathBuf>) -> &mut Self {
self.model_path = model_path;
self
}
pub fn model_name(&mut self, model_name: Option<String>) -> &mut Self {
self.model_name = model_name;
self
}
pub fn card(&self) -> &ModelDeploymentCard {
&self.card
pub fn model_config(&mut self, model_config: Option<PathBuf>) -> &mut Self {
self.model_config = model_config;
self
}
pub fn path(&self) -> &Path {
&self.full_path
pub fn endpoint_id(&mut self, endpoint_id: EndpointId) -> &mut Self {
self.endpoint_id = Some(endpoint_id);
self
}
pub fn display_name(&self) -> &str {
&self.card.display_name
pub fn context_length(&mut self, context_length: Option<u32>) -> &mut Self {
self.context_length = context_length;
self
}
pub fn service_name(&self) -> &str {
&self.card.service_name
/// Passing None resets it to default
pub fn kv_cache_block_size(&mut self, kv_cache_block_size: Option<u32>) -> &mut Self {
self.kv_cache_block_size = kv_cache_block_size.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE);
self
}
pub fn is_gguf(&self) -> bool {
// GGUF is the only file (not-folder) we accept, so we don't need to check the extension
// We will error when we come to parse it
self.full_path.is_file()
pub fn http_port(&mut self, port: u16) -> &mut Self {
self.http_port = port;
self
}
/// Override max number of tokens in context. We usually only do this to limit kv cache allocation.
pub fn set_context_length(&mut self, context_length: usize) {
self.card.context_length = context_length;
pub fn router_config(&mut self, router_config: RouterConfig) -> &mut Self {
self.router_config = Some(router_config);
self
}
pub fn set_kv_cache_block_size(&mut self, block_size: usize) {
self.card.kv_cache_block_size = block_size;
pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
self.template_file = template_file;
self
}
/// Make an LLM ready for use:
......@@ -88,28 +120,60 @@ impl LocalModel {
/// The model name will depend on what "model_path" is:
/// - A folder: The last part of the folder name: "/data/llms/Qwen2.5-3B-Instruct" -> "Qwen2.5-3B-Instruct"
/// - A file: The GGUF filename: "/data/llms/Qwen2.5-3B-Instruct-Q6_K.gguf" -> "Qwen2.5-3B-Instruct-Q6_K.gguf"
/// - An HF repo: The HF repo name: "Qwen/Qwen2.5-3B-Instruct" stays the same
pub async fn prepare(
model_path: &str,
override_config: Option<&Path>,
override_name: Option<String>,
) -> anyhow::Result<LocalModel> {
// Name it
/// - An HF repo: The HF repo name: "Qwen/Qwen3-0.6B" stays the same
pub async fn build(&mut self) -> anyhow::Result<LocalModel> {
// Generate an endpoint ID for this model if the user didn't provide one.
// The user only provides one if exposing the model.
let endpoint_id = self
.endpoint_id
.take()
.unwrap_or_else(|| internal_endpoint("local_model"));
let template = self
.template_file
.as_deref()
.map(RequestTemplate::load)
.transpose()?;
// echo_full engine doesn't need a path. It's an edge case, move it out of the way.
if self.model_path.is_none() {
return Ok(LocalModel {
card: ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
),
full_path: PathBuf::new(),
endpoint_id,
template,
http_port: self.http_port,
// We always have one. The Option is so we can take it.
router_config: self
.router_config
.take()
.expect("unreachable, RouterConfig missing"),
});
}
// Main logic. We are running a model.
let model_path = self.model_path.take().unwrap();
let model_path = model_path.to_str().context("Invalid UTF-8 in model path")?;
// Check for hf:// prefix first, in case we really want an HF repo but it conflicts
// with a relative path.
let is_hf_repo =
model_path.starts_with(HF_SCHEME) || !fs::exists(model_path).unwrap_or(false);
let relative_path = model_path.trim_start_matches(HF_SCHEME);
let full_path = if is_hf_repo {
// HF download if necessary
super::hub::from_hf(relative_path).await?
} else {
fs::canonicalize(relative_path)?
};
// --model-config takes precedence over --model-path
let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
let mut card = ModelDeploymentCard::load(&model_config_path).await?;
let model_name = override_name.unwrap_or_else(|| {
// Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| {
if is_hf_repo {
// HF repos use their full name ("org/name") not the folder name
relative_path.to_string()
......@@ -124,15 +188,83 @@ impl LocalModel {
})
}
});
card.set_name(&model_name);
// Load the ModelDeploymentCard
card.kv_cache_block_size = self.kv_cache_block_size;
// --model-config takes precedence over --model-path
let model_config_path = override_config.unwrap_or(&full_path);
let mut card = ModelDeploymentCard::load(&model_config_path).await?;
card.set_name(&model_name);
// Override max number of tokens in context. We usually only do this to limit kv cache allocation.
if let Some(context_length) = self.context_length {
card.context_length = context_length;
}
Ok(LocalModel { full_path, card })
Ok(LocalModel {
card,
full_path,
endpoint_id,
template,
http_port: self.http_port,
router_config: self
.router_config
.take()
.expect("unreachable, RouterConfig missing"),
})
}
}
/// A model together with the settings it was configured with: where it lives, its
/// deployment card, the endpoint identifying it, an optional request template,
/// the HTTP port and the router configuration.
#[derive(Debug, Clone)]
pub struct LocalModel {
/// Location of the model. Left empty when the model was built without a model path.
full_path: PathBuf,
/// Deployment card describing the model (name, context length, KV cache block size, ...).
card: ModelDeploymentCard,
/// Endpoint identifying this model. Generated when the user did not supply one.
endpoint_id: EndpointId,
/// Request template loaded from the configured template file, if any.
template: Option<RequestTemplate>,
http_port: u16, // Only used if input is HTTP server
/// Router settings, exposed through `router_config()`.
router_config: RouterConfig,
}
impl LocalModel {
/// The deployment card describing this model.
pub fn card(&self) -> &ModelDeploymentCard {
&self.card
}
/// Where the model lives, as stored in `full_path`.
/// This is an empty path when the model was built without a model path.
pub fn path(&self) -> &Path {
&self.full_path
}
/// The `display_name` recorded on this model's deployment card.
pub fn display_name(&self) -> &str {
&self.card.display_name
}
/// The `service_name` recorded on this model's deployment card.
pub fn service_name(&self) -> &str {
&self.card.service_name
}
/// A copy of the request template, or `None` if no template was configured.
/// The template is cloned on every call, so the caller owns the returned value.
pub fn request_template(&self) -> Option<RequestTemplate> {
self.template.clone()
}
/// The HTTP port this model was configured with.
/// Only meaningful when the input is an HTTP server.
pub fn http_port(&self) -> u16 {
self.http_port
}
/// The router configuration this model was built with.
pub fn router_config(&self) -> &RouterConfig {
&self.router_config
}
/// Whether this model is a single GGUF file rather than a folder.
/// Only checks that `full_path` is a file; the file contents are not inspected here.
pub fn is_gguf(&self) -> bool {
// GGUF is the only file (not-folder) we accept, so we don't need to check the extension
// We will error when we come to parse it
self.full_path.is_file()
}
/// An endpoint to identify this model by.
/// This is either the ID the user supplied or one generated for internal use.
pub fn endpoint_id(&self) -> &EndpointId {
&self.endpoint_id
}
/// Drop the LocalModel, returning its ModelDeploymentCard.
/// For the case where we only need the card and don't want to clone it.
pub fn into_card(self) -> ModelDeploymentCard {
self.card
}
/// Attach this model to the endpoint. This registers it on the network
......@@ -202,3 +334,13 @@ impl LocalModel {
Ok(())
}
}
/// Build a throwaway endpoint ID for internal communication.
///
/// The namespace is a freshly generated UUID rather than a fixed string, because
/// several instances may be running on the same machine (GPUs 0-3 and 4-7).
/// `engine` becomes the component, and the endpoint name is always "generate".
fn internal_endpoint(engine: &str) -> EndpointId {
    // A new random namespace on every call keeps co-located instances apart.
    let random_id = uuid::Uuid::new_v4().to_string();
    let namespace = Slug::slugify(&random_id).to_string();
    let component = engine.to_string();
    let name = "generate".to_string();
    EndpointId {
        namespace,
        component,
        name,
    }
}
......@@ -57,7 +57,7 @@ pub struct KvManager {
max_capacity: usize,
#[getter(copy)]
block_size: usize,
block_size: u32,
active_blocks: HashMap<UniqueBlock, usize>,
......@@ -67,7 +67,7 @@ pub struct KvManager {
}
impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
pub fn new(max_capacity: usize, block_size: u32) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new();
......@@ -245,7 +245,7 @@ impl KvManager {
let overlap_blocks = unique_blocks.len() - new_blocks;
// Calculate new tokens
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size;
let new_tokens = sequence.num_input_tokens() - overlap_blocks * (self.block_size as usize);
// // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
......@@ -261,7 +261,7 @@ impl KvManager {
// Calculate prefill compute
let prefill_compute =
new_tokens as f64 * (new_tokens + overlap_blocks * self.block_size) as f64;
new_tokens as f64 * (new_tokens + overlap_blocks * (self.block_size as usize)) as f64;
Some(PrefillCost {
new_tokens,
......
......@@ -193,7 +193,7 @@ impl Scheduler {
pub fn new(
kv_capacity: usize,
watermark: f64,
block_size: usize,
block_size: u32,
chunk_size: Option<usize>,
output_tx: Option<mpsc::Sender<Uuid>>,
cancellation_token: Option<CancellationToken>,
......@@ -272,7 +272,7 @@ impl Scheduler {
let mut kv_manager_guard = kv_manager_clone.lock().await;
// Base time needed for decoding (assumed memory bound on KV cache)
let active_tokens = kv_manager_guard.num_active_blocks() * block_size;
let active_tokens = kv_manager_guard.num_active_blocks() * (block_size as usize);
// TODO: 2 is a dummy / magic scaling factor
let mut generation_time = Duration::from_micros((active_tokens / 2) as u64);
......@@ -406,7 +406,7 @@ impl Scheduler {
}
/// Convert a Request to an ActiveSequence
fn get_active_sequence(request: Request, block_size: usize, chunk_size: usize) -> ActiveSequence {
fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) -> ActiveSequence {
if let Request::Active(active_seq) = request {
return active_seq;
}
......@@ -475,7 +475,7 @@ mod tests {
let kv_capacity: usize = 500;
let watermark: f64 = 0.01; // 1% watermark
let block_size: usize = 64;
let block_size: u32 = 64;
let chunk_size: usize = 256;
let num_requests: usize = 100;
let input_len: usize = 1000;
......
......@@ -23,7 +23,7 @@ use uuid;
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: usize,
block_size: u32,
) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks()
......@@ -32,7 +32,7 @@ fn create_unique_blocks_from_sequence(
.collect();
// Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % block_size != 0 {
if tokens.total_tokens() % (block_size as usize) != 0 {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
......@@ -50,7 +50,7 @@ pub struct ActiveSequence {
tokens: TokenBlockSequence,
#[getter(copy)]
block_size: usize,
block_size: u32,
#[getter(copy)]
chunk_size: usize, // TODO: not actually used
......@@ -72,7 +72,7 @@ impl ActiveSequence {
pub fn new(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<usize>,
block_size: Option<u32>,
chunk_size: Option<usize>,
) -> Self {
let block_size = block_size.unwrap_or(64);
......@@ -96,8 +96,8 @@ impl ActiveSequence {
}
}
pub fn extra_tokens(&self) -> usize {
self.len() % self.block_size
pub fn extra_tokens(&self) -> u32 {
(self.len() % self.block_size as usize) as u32
}
pub fn len(&self) -> usize {
......@@ -112,7 +112,7 @@ impl ActiveSequence {
pub fn new_with_signal(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<usize>,
block_size: Option<u32>,
chunk_size: Option<usize>,
) -> (Self, Option<MoveBlock>) {
let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size);
......@@ -125,7 +125,7 @@ impl ActiveSequence {
self.tokens.append(token).expect("Token push failed.");
self.generated_tokens += 1;
if self.len() % self.block_size != 1 {
if self.len() % (self.block_size as usize) != 1 {
return None;
}
......@@ -223,7 +223,7 @@ impl ActiveSequence {
self.generated_tokens = self.generated_tokens.saturating_sub(1);
// Reverts to the last full block
if self.tokens.total_tokens() % self.block_size == 0 {
if self.tokens.total_tokens() % (self.block_size as usize) == 0 {
self.unique_blocks.pop();
}
}
......@@ -285,7 +285,7 @@ mod tests {
// Verify state after pushing tokens
assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
assert_eq!(seq1.len(), 17);
assert_eq!(seq1.len() % seq1.block_size(), 1);
assert_eq!(seq1.len() % (seq1.block_size() as usize), 1);
// Create another sequence with block size 16 initialized with tokens [0..17]
let extended_tokens: Vec<u32> = (0..16).collect();
......@@ -335,12 +335,12 @@ mod tests {
"seq2 should have exactly 3 blocks"
);
assert_eq!(
seq1.len() % seq1.block_size(),
seq1.len() % (seq1.block_size() as usize),
1,
"seq1 should have 1 partial token"
);
assert_eq!(
seq2.len() % seq2.block_size(),
seq2.len() % (seq2.block_size() as usize),
1,
"seq2 should have 1 partial token"
);
......
......@@ -76,7 +76,7 @@ impl ModelDeploymentCard {
let content = super::model::load_gguf(gguf_file)?;
let context_length = content.get_metadata()[&format!("{}.context_length", content.arch())]
.to_u32()
.unwrap_or(0) as usize;
.unwrap_or(0);
tracing::debug!(context_length, "Loaded context length from GGUF");
Ok(Self {
......
......@@ -117,11 +117,11 @@ pub struct ModelDeploymentCard {
pub revision: u64,
/// Max context (in number of tokens) this model can handle
pub context_length: usize,
pub context_length: u32,
/// Size of a KV cache block - vllm only currently
/// Passed to the engine and the KV router.
pub kv_cache_block_size: usize,
pub kv_cache_block_size: u32,
}
impl ModelDeploymentCard {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment