"components/backends/vscode:/vscode.git/clone" did not exist on "e5b6a054b51bf21658919037e8580133cbdf3fae"
Unverified Commit 92f06b0e authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(dynamo-run): Refactor to library (#1687)

Move much of what was in the `dynamo-run` crate into `dynamo-llm` so that everyone can use it.

Example usage:

1. Create a `LocalModel`:

```
    let local_model = LocalModelBuilder::default()
	.model_path("Qwen/Qwen3-0.6B")
	.http_port(8080)
	.build().await?;
```

2. Make an engine:

```
    let engine_config = EngineConfig::StaticFull {
	engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
	model: Box::new(local_model),
    };
```

3. Connect it to an input and run it

```
    dynamo_llm::entrypoint::input::run_input(Input::Http, runtime, engine_config).await?;
```

For https://github.com/ai-dynamo/dynamo/issues/1647

Code Rabbit summary, thanks:
  * Introduced a flexible builder pattern for local model configuration, allowing advanced customization and easier initialization.
  * Added new input modes and unified input handling, supporting interactive chat, HTTP server, batch file, and distributed endpoint modes.
  * Centralized engine configuration and routing, enabling more extensible and maintainable engine management.
  * Simplified and modularized the codebase by moving input and engine logic into dedicated modules.
  * Replaced direct model construction with an asynchronous builder for improved clarity and extensibility.
  * Streamlined configuration and validation for flags and router settings.
  * Added validation to prevent incompatible input and output combinations in endpoint and dynamic modes.
parent 3b62692f
...@@ -180,7 +180,7 @@ impl ModelManager { ...@@ -180,7 +180,7 @@ impl ModelManager {
&self, &self,
model_name: &str, model_name: &str,
component: &Component, component: &Component,
kv_cache_block_size: usize, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
if let Some(kv_chooser) = self.get_kv_chooser(model_name) { if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
...@@ -209,7 +209,7 @@ impl ModelManager { ...@@ -209,7 +209,7 @@ impl ModelManager {
&self, &self,
model_name: &str, model_name: &str,
component: &Component, component: &Component,
kv_cache_block_size: usize, kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! The entrypoint module provides tools to build a Dynamo runner.
//! - Create an EngineConfig of the engine (potentially auto-discovered) to execute
//! - Connect it to an Input
pub mod input;
use std::sync::Arc;
use dynamo_runtime::pipeline::RouterMode;
use crate::{
backend::ExecutionContext, engines::StreamingEngine, kv_router::KvRouterConfig,
local_model::LocalModel,
};
#[derive(Debug, Clone, Default)]
pub struct RouterConfig {
pub router_mode: RouterMode,
pub kv_router_config: KvRouterConfig,
}
impl RouterConfig {
pub fn new(router_mode: RouterMode, kv_router_config: KvRouterConfig) -> Self {
Self {
router_mode,
kv_router_config,
}
}
}
pub enum EngineConfig {
/// Remote networked engines
Dynamic(Box<LocalModel>),
/// A Full service engine does it's own tokenization and prompt formatting.
StaticFull {
engine: Arc<dyn StreamingEngine>,
model: Box<LocalModel>,
},
/// A core engine expects to be wrapped with pre/post processors that handle tokenization.
StaticCore {
engine: ExecutionContext,
model: Box<LocalModel>,
},
}
impl EngineConfig {
fn local_model(&self) -> &LocalModel {
use EngineConfig::*;
match self {
Dynamic(lm) => lm,
StaticFull { model, .. } => model,
StaticCore { model, .. } => model,
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! This module contains tools to gather a prompt from a user, forward it to an engine and return
//! the response.
//! See the Input enum for the inputs available. Input::Http (OpenAI compatible HTTP server)
//! and Input::Text (interactive chat) are good places to start.
//! The main entry point is `run_input`.
use std::{
fmt,
io::{IsTerminal as _, Read as _},
path::PathBuf,
};
pub mod batch;
mod common;
pub mod endpoint;
pub mod http;
pub mod text;
use dynamo_runtime::{protocols::ENDPOINT_SCHEME, DistributedRuntime};
const BATCH_PREFIX: &str = "batch:";
/// The various ways of connecting prompts to an engine
#[derive(PartialEq)]
pub enum Input {
/// Run an OpenAI compatible HTTP server
Http,
/// Single prompt on stdin
Stdin,
/// Interactive chat
Text,
/// Pull requests from a namespace/component/endpoint path.
Endpoint(String),
/// Batch mode. Run all the prompts, write the outputs, exit.
Batch(PathBuf),
}
impl TryFrom<&str> for Input {
type Error = anyhow::Error;
fn try_from(s: &str) -> anyhow::Result<Self> {
match s {
"http" => Ok(Input::Http),
"text" => Ok(Input::Text),
"stdin" => Ok(Input::Stdin),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
Ok(Input::Endpoint(endpoint_path.to_string()))
}
batch_patch if batch_patch.starts_with(BATCH_PREFIX) => {
let path = batch_patch.strip_prefix(BATCH_PREFIX).unwrap();
Ok(Input::Batch(PathBuf::from(path)))
}
e => Err(anyhow::anyhow!("Invalid in= option '{e}'")),
}
}
}
impl fmt::Display for Input {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let s = match self {
Input::Http => "http",
Input::Text => "text",
Input::Stdin => "stdin",
Input::Endpoint(path) => path,
Input::Batch(path) => &path.display().to_string(),
};
write!(f, "{s}")
}
}
impl Default for Input {
fn default() -> Self {
if std::io::stdin().is_terminal() {
Input::Text
} else {
Input::Stdin
}
}
}
/// Run the given engine (EngineConfig) connected to an input.
/// Does not return until the input exits.
pub async fn run_input(
in_opt: Input,
runtime: dynamo_runtime::Runtime,
engine_config: super::EngineConfig,
) -> anyhow::Result<()> {
match in_opt {
Input::Http => {
http::run(runtime.clone(), engine_config).await?;
}
Input::Text => {
text::run(runtime.clone(), None, engine_config).await?;
}
Input::Stdin => {
let mut prompt = String::new();
std::io::stdin().read_to_string(&mut prompt).unwrap();
text::run(runtime.clone(), Some(prompt), engine_config).await?;
}
Input::Batch(path) => {
batch::run(runtime.clone(), path, engine_config).await?;
}
Input::Endpoint(path) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
endpoint::run(distributed_runtime, path, engine_config).await?;
}
}
Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use anyhow::Context as _; use crate::preprocessor::OpenAIPreprocessor;
use async_openai::types::FinishReason; use crate::request_template::RequestTemplate;
use dynamo_llm::model_card::model::ModelDeploymentCard; use crate::types::openai::chat_completions::{
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::request_template::RequestTemplate;
use dynamo_llm::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
}; };
use anyhow::Context as _;
use async_openai::types::FinishReason;
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime}; use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt; use futures::StreamExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -19,8 +18,8 @@ use std::sync::Arc; ...@@ -19,8 +18,8 @@ use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use crate::input::common; use crate::entrypoint::input::common;
use crate::{EngineConfig, Flags}; use crate::entrypoint::EngineConfig;
/// Max tokens in each response. /// Max tokens in each response.
/// TODO: For batch mode this should be the full context size of the model /// TODO: For batch mode this should be the full context size of the model
...@@ -53,11 +52,8 @@ struct Entry { ...@@ -53,11 +52,8 @@ struct Entry {
pub async fn run( pub async fn run(
runtime: Runtime, runtime: Runtime,
_flags: Flags,
card: ModelDeploymentCard,
input_jsonl: PathBuf, input_jsonl: PathBuf,
engine_config: EngineConfig, engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let cancel_token = runtime.primary_token();
// Check if the path exists and is a directory // Check if the path exists and is a directory
...@@ -68,11 +64,10 @@ pub async fn run( ...@@ -68,11 +64,10 @@ pub async fn run(
); );
} }
let prepared_engine = common::prepare_engine(runtime, engine_config).await?; let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let service_name_ref = Arc::new(prepared_engine.service_name);
let pre_processor = if card.has_tokenizer() { let pre_processor = if prepared_engine.has_tokenizer() {
Some(OpenAIPreprocessor::new(card).await?) Some(OpenAIPreprocessor::new(prepared_engine.card.take().unwrap()).await?)
} else { } else {
None None
}; };
...@@ -85,6 +80,7 @@ pub async fn run( ...@@ -85,6 +80,7 @@ pub async fn run(
tracing::error!(%err, "Failed writing output to {}", output_file.display()); tracing::error!(%err, "Failed writing output to {}", output_file.display());
} }
}); });
let service_name_ref = Arc::new(prepared_engine.service_name);
let tokens_in = Arc::new(AtomicU64::new(0)); let tokens_in = Arc::new(AtomicU64::new(0));
let tokens_out = Arc::new(AtomicU64::new(0)); let tokens_out = Arc::new(AtomicU64::new(0));
...@@ -98,7 +94,7 @@ pub async fn run( ...@@ -98,7 +94,7 @@ pub async fn run(
tracing::info!("Timer start."); tracing::info!("Timer start.");
let start = Instant::now(); let start = Instant::now();
let mut lines = buffered_input.lines(); let mut lines = buffered_input.lines();
let template: Option<Arc<RequestTemplate>> = template.map(Arc::new); let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
while let Ok(Some(line)) = lines.next_line().await { while let Ok(Some(line)) = lines.next_line().await {
if cancel_token.is_cancelled() { if cancel_token.is_cancelled() {
break; break;
......
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
use std::pin::Pin; use std::pin::Pin;
use dynamo_llm::{ use crate::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH}, discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::EngineConfig,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
protocols::common::llm_backend::{BackendOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, PreprocessedRequest},
request_template::RequestTemplate,
types::{ types::{
openai::chat_completions::{ openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
...@@ -25,12 +27,22 @@ use dynamo_runtime::{ ...@@ -25,12 +27,22 @@ use dynamo_runtime::{
}; };
use std::sync::Arc; use std::sync::Arc;
use crate::EngineConfig;
pub struct PreparedEngine { pub struct PreparedEngine {
pub service_name: String, pub service_name: String,
pub engine: OpenAIChatCompletionsStreamingEngine, pub engine: OpenAIChatCompletionsStreamingEngine,
pub inspect_template: bool, pub inspect_template: bool,
pub card: Option<ModelDeploymentCard>,
pub request_template: Option<RequestTemplate>,
}
impl PreparedEngine {
pub fn has_tokenizer(&self) -> bool {
if let Some(card) = self.card.as_ref() {
card.has_tokenizer()
} else {
false
}
}
} }
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine. /// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
...@@ -39,7 +51,7 @@ pub async fn prepare_engine( ...@@ -39,7 +51,7 @@ pub async fn prepare_engine(
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<PreparedEngine> { ) -> anyhow::Result<PreparedEngine> {
match engine_config { match engine_config {
EngineConfig::Dynamic => { EngineConfig::Dynamic(local_model) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let Some(etcd_client) = distributed_runtime.etcd_client() else { let Some(etcd_client) = distributed_runtime.etcd_client() else {
...@@ -71,6 +83,8 @@ pub async fn prepare_engine( ...@@ -71,6 +83,8 @@ pub async fn prepare_engine(
service_name: model_service_name, service_name: model_service_name,
engine, engine,
inspect_template: false, inspect_template: false,
card: None,
request_template: local_model.request_template(),
}) })
} }
EngineConfig::StaticFull { engine, model } => { EngineConfig::StaticFull { engine, model } => {
...@@ -81,6 +95,8 @@ pub async fn prepare_engine( ...@@ -81,6 +95,8 @@ pub async fn prepare_engine(
service_name, service_name,
engine, engine,
inspect_template: false, inspect_template: false,
request_template: model.request_template(),
card: Some(model.into_card()),
}) })
} }
EngineConfig::StaticCore { EngineConfig::StaticCore {
...@@ -99,6 +115,8 @@ pub async fn prepare_engine( ...@@ -99,6 +115,8 @@ pub async fn prepare_engine(
service_name, service_name,
engine: pipeline, engine: pipeline,
inspect_template: true, inspect_template: true,
request_template: model.request_template(),
card: Some(model.into_card()),
}) })
} }
} }
...@@ -137,21 +155,21 @@ where ...@@ -137,21 +155,21 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use dynamo_llm::types::openai::{ use crate::types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}; };
const HF_PATH: &str = concat!( const HF_PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"), env!("CARGO_MANIFEST_DIR"),
"/../../lib/llm/tests/data/sample-models/mock-llama-3.1-8b-instruct" "/tests/data/sample-models/mock-llama-3.1-8b-instruct"
); );
#[tokio::test] #[tokio::test]
async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card // Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?; let card = ModelDeploymentCard::load(HF_PATH).await?;
let engine = dynamo_llm::engines::make_engine_core(); let engine = crate::engines::make_engine_core();
// Build pipeline for chat completions // Build pipeline for chat completions
let pipeline = build_pipeline::< let pipeline = build_pipeline::<
...@@ -170,7 +188,7 @@ mod tests { ...@@ -170,7 +188,7 @@ mod tests {
async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card // Create test model card
let card = ModelDeploymentCard::load(HF_PATH).await?; let card = ModelDeploymentCard::load(HF_PATH).await?;
let engine = dynamo_llm::engines::make_engine_core(); let engine = crate::engines::make_engine_core();
// Build pipeline for completions // Build pipeline for completions
let pipeline = let pipeline =
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{future::Future, pin::Pin, sync::Arc}; use std::{future::Future, pin::Pin, sync::Arc};
use dynamo_llm::{ use crate::{
backend::Backend, backend::Backend,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
model_type::ModelType, model_type::ModelType,
...@@ -33,7 +21,7 @@ use dynamo_runtime::pipeline::{ ...@@ -33,7 +21,7 @@ use dynamo_runtime::pipeline::{
}; };
use dynamo_runtime::{protocols::Endpoint as EndpointId, DistributedRuntime}; use dynamo_runtime::{protocols::Endpoint as EndpointId, DistributedRuntime};
use crate::EngineConfig; use crate::entrypoint::EngineConfig;
pub async fn run( pub async fn run(
distributed_runtime: DistributedRuntime, distributed_runtime: DistributedRuntime,
...@@ -91,7 +79,7 @@ pub async fn run( ...@@ -91,7 +79,7 @@ pub async fn run(
(Box::pin(fut), Some(model.card().clone())) (Box::pin(fut), Some(model.card().clone()))
} }
EngineConfig::Dynamic => { EngineConfig::Dynamic(_) => {
// We can only get here for in=dyn out=vllm|sglang`, because vllm and sglang are a // We can only get here for in=dyn out=vllm|sglang`, because vllm and sglang are a
// subprocess that we talk to like a remote endpoint. // subprocess that we talk to like a remote endpoint.
// That means the vllm/sglang subprocess is doing all the work, we are idle. // That means the vllm/sglang subprocess is doing all the work, we are idle.
......
...@@ -3,14 +3,12 @@ ...@@ -3,14 +3,12 @@
use std::sync::Arc; use std::sync::Arc;
use crate::input::common; use crate::{
use crate::{EngineConfig, Flags};
use dynamo_llm::kv_router::KvRouterConfig;
use dynamo_llm::{
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH}, discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{input::common, EngineConfig},
http::service::service_v2, http::service::service_v2,
request_template::RequestTemplate, kv_router::KvRouterConfig,
types::{ types::{
openai::chat_completions::{ openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
...@@ -23,32 +21,28 @@ use dynamo_runtime::transports::etcd; ...@@ -23,32 +21,28 @@ use dynamo_runtime::transports::etcd;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
/// Build and run an HTTP service /// Build and run an HTTP service
pub async fn run( pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
runtime: Runtime,
flags: Flags,
engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> {
let http_service = service_v2::HttpService::builder() let http_service = service_v2::HttpService::builder()
.port(flags.http_port) .port(engine_config.local_model().http_port())
.enable_chat_endpoints(true) .enable_chat_endpoints(true)
.enable_cmpl_endpoints(true) .enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true) .enable_embeddings_endpoints(true)
.with_request_template(template) .with_request_template(engine_config.local_model().request_template())
.build()?; .build()?;
match engine_config { match engine_config {
EngineConfig::Dynamic => { EngineConfig::Dynamic(_) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
match distributed_runtime.etcd_client() { match distributed_runtime.etcd_client() {
Some(etcd_client) => { Some(etcd_client) => {
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to HTTP service // Listen for models registering themselves in etcd, add them to HTTP service
run_watcher( run_watcher(
distributed_runtime, distributed_runtime,
http_service.state().manager_clone(), http_service.state().manager_clone(),
etcd_client.clone(), etcd_client.clone(),
MODEL_ROOT_PATH, MODEL_ROOT_PATH,
flags.router_mode.into(), router_config.router_mode,
Some(flags.kv_router_config()), Some(router_config.kv_router_config.clone()),
) )
.await?; .await?;
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_llm::protocols::openai::nvext::NvExt; use crate::protocols::openai::nvext::NvExt;
use dynamo_llm::types::openai::chat_completions::{ use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
}; };
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime}; use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt; use futures::StreamExt;
use std::io::{ErrorKind, Write}; use std::io::{ErrorKind, Write};
use crate::input::common; use crate::entrypoint::input::common;
use crate::{EngineConfig, Flags, RequestTemplate}; use crate::entrypoint::EngineConfig;
/// Max response tokens for each single query. Must be less than model context size. /// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this /// TODO: Cmd line flag to overwrite this
...@@ -18,20 +19,19 @@ const MAX_TOKENS: u32 = 8192; ...@@ -18,20 +19,19 @@ const MAX_TOKENS: u32 = 8192;
pub async fn run( pub async fn run(
runtime: Runtime, runtime: Runtime,
_flags: Flags,
single_prompt: Option<String>, single_prompt: Option<String>,
engine_config: EngineConfig, engine_config: EngineConfig,
template: Option<RequestTemplate>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let cancel_token = runtime.primary_token();
let prepared_engine = common::prepare_engine(runtime, engine_config).await?; let prepared_engine = common::prepare_engine(runtime, engine_config).await?;
// TODO: Pass prepared_engine directly
main_loop( main_loop(
cancel_token, cancel_token,
&prepared_engine.service_name, &prepared_engine.service_name,
prepared_engine.engine, prepared_engine.engine,
single_prompt, single_prompt,
prepared_engine.inspect_template, prepared_engine.inspect_template,
template, prepared_engine.request_template,
) )
.await .await
} }
......
...@@ -50,7 +50,7 @@ pub trait WorkerSelector { ...@@ -50,7 +50,7 @@ pub trait WorkerSelector {
&self, &self,
workers: &ProcessedEndpoints, workers: &ProcessedEndpoints,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: usize, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>; ) -> Result<WorkerSelectionResult, KvSchedulerError>;
} }
...@@ -104,13 +104,13 @@ impl KvRouterConfig { ...@@ -104,13 +104,13 @@ impl KvRouterConfig {
pub struct KvRouter { pub struct KvRouter {
indexer: KvIndexer, indexer: KvIndexer,
scheduler: KvScheduler, scheduler: KvScheduler,
block_size: usize, block_size: u32,
} }
impl KvRouter { impl KvRouter {
pub async fn new( pub async fn new(
component: Component, component: Component,
block_size: usize, block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Self> { ) -> Result<Self> {
let cancellation_token = component let cancellation_token = component
...@@ -196,7 +196,7 @@ impl KvRouter { ...@@ -196,7 +196,7 @@ impl KvRouter {
} }
/// Get the block size this router was configured with /// Get the block size this router was configured with
pub fn block_size(&self) -> usize { pub fn block_size(&self) -> u32 {
self.block_size self.block_size
} }
} }
......
...@@ -119,9 +119,9 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { ...@@ -119,9 +119,9 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
/// ### Returns /// ### Returns
/// ///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens. /// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: usize) -> Vec<LocalBlockHash> { pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> {
tokens tokens
.chunks_exact(kv_block_size) // Split into chunks of kv_block_size elements .chunks_exact(kv_block_size as usize) // Split into chunks of kv_block_size elements
.map(|chunk| { .map(|chunk| {
let bytes: Vec<u8> = chunk let bytes: Vec<u8> = chunk
.iter() .iter()
...@@ -527,7 +527,7 @@ pub struct KvIndexer { ...@@ -527,7 +527,7 @@ pub struct KvIndexer {
/// A handle to the background task managing the KV store. /// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>, task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle. /// The size of the KV block this indexer can handle.
kv_block_size: usize, kv_block_size: u32,
} }
impl KvIndexer { impl KvIndexer {
...@@ -544,7 +544,7 @@ impl KvIndexer { ...@@ -544,7 +544,7 @@ impl KvIndexer {
pub fn new_with_frequency( pub fn new_with_frequency(
token: CancellationToken, token: CancellationToken,
expiration_duration: Option<Duration>, expiration_duration: Option<Duration>,
kv_block_size: usize, kv_block_size: u32,
) -> Self { ) -> Self {
let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048); let (event_tx, event_rx) = mpsc::channel::<RouterEvent>(2048);
let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128); let (match_tx, match_rx) = mpsc::channel::<MatchRequest>(128);
...@@ -611,11 +611,11 @@ impl KvIndexer { ...@@ -611,11 +611,11 @@ impl KvIndexer {
} }
} }
pub fn block_size(&self) -> usize { pub fn block_size(&self) -> u32 {
self.kv_block_size self.kv_block_size
} }
pub fn new(token: CancellationToken, kv_block_size: usize) -> Self { pub fn new(token: CancellationToken, kv_block_size: u32) -> Self {
Self::new_with_frequency(token, None, kv_block_size) Self::new_with_frequency(token, None, kv_block_size)
} }
...@@ -697,7 +697,7 @@ pub struct KvIndexerSharded { ...@@ -697,7 +697,7 @@ pub struct KvIndexerSharded {
/// A `CancellationToken` for managing shutdown. /// A `CancellationToken` for managing shutdown.
cancel: CancellationToken, cancel: CancellationToken,
/// The size of the KV block this indexer can handle. /// The size of the KV block this indexer can handle.
kv_block_size: usize, kv_block_size: u32,
worker_assignments: HashMap<WorkerId, usize>, worker_assignments: HashMap<WorkerId, usize>,
worker_counts: Vec<usize>, worker_counts: Vec<usize>,
...@@ -723,7 +723,7 @@ impl KvIndexerSharded { ...@@ -723,7 +723,7 @@ impl KvIndexerSharded {
token: CancellationToken, token: CancellationToken,
num_shards: usize, num_shards: usize,
expiration_duration: Option<Duration>, expiration_duration: Option<Duration>,
kv_block_size: usize, kv_block_size: u32,
) -> Self { ) -> Self {
let worker_assignments: HashMap<WorkerId, usize> = HashMap::new(); let worker_assignments: HashMap<WorkerId, usize> = HashMap::new();
let worker_counts: Vec<usize> = vec![0; num_shards]; let worker_counts: Vec<usize> = vec![0; num_shards];
...@@ -802,11 +802,11 @@ impl KvIndexerSharded { ...@@ -802,11 +802,11 @@ impl KvIndexerSharded {
} }
} }
pub fn block_size(&self) -> usize { pub fn block_size(&self) -> u32 {
self.kv_block_size self.kv_block_size
} }
pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: usize) -> Self { pub fn new(token: CancellationToken, num_shards: usize, kv_block_size: u32) -> Self {
Self::new_with_frequency(token, num_shards, None, kv_block_size) Self::new_with_frequency(token, num_shards, None, kv_block_size)
} }
} }
...@@ -1312,24 +1312,20 @@ mod tests { ...@@ -1312,24 +1312,20 @@ mod tests {
#[case(11)] #[case(11)]
#[case(32)] #[case(32)]
#[case(64)] #[case(64)]
fn test_compute_block_hash_for_seq(#[case] kv_block_size: usize) { fn test_compute_block_hash_for_seq(#[case] kv_block_size: u32) {
setup(); setup();
// create a sequence of 64 elements // create a sequence of 64 elements
let sequence = (0..kv_block_size).map(|i| i as u32).collect::<Vec<u32>>(); let sequence = (0..kv_block_size).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of 65 elements // create a sequence of 65 elements
let sequence = (0..(kv_block_size + 1)) let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
.map(|i| i as u32)
.collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of 129 elements // create a sequence of 129 elements
let sequence = (0..(2 * kv_block_size + 1)) let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
.map(|i| i as u32)
.collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size);
assert_eq!(hashes.len(), 2); assert_eq!(hashes.len(), 2);
} }
...@@ -1337,7 +1333,7 @@ mod tests { ...@@ -1337,7 +1333,7 @@ mod tests {
fn make_indexer( fn make_indexer(
token: &CancellationToken, token: &CancellationToken,
num_shards: usize, num_shards: usize,
kv_block_size: usize, kv_block_size: u32,
) -> Box<dyn KvIndexerInterface> { ) -> Box<dyn KvIndexerInterface> {
if num_shards == 1 { if num_shards == 1 {
Box::new(KvIndexer::new(token.clone(), kv_block_size)) Box::new(KvIndexer::new(token.clone(), kv_block_size))
...@@ -1360,7 +1356,7 @@ mod tests { ...@@ -1360,7 +1356,7 @@ mod tests {
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_kv_indexer_new(num_shards: usize, kv_block_size: usize) { async fn test_kv_indexer_new(num_shards: usize, kv_block_size: u32) {
setup(); setup();
let token: CancellationToken = CancellationToken::new(); let token: CancellationToken = CancellationToken::new();
let _ = make_indexer(&token, num_shards, kv_block_size); let _ = make_indexer(&token, num_shards, kv_block_size);
...@@ -1368,7 +1364,7 @@ mod tests { ...@@ -1368,7 +1364,7 @@ mod tests {
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_find_matches(num_shards: usize, kv_block_size: usize) { async fn test_find_matches(num_shards: usize, kv_block_size: u32) {
setup(); setup();
let token = CancellationToken::new(); let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size); let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
...@@ -1381,7 +1377,7 @@ mod tests { ...@@ -1381,7 +1377,7 @@ mod tests {
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_find_matches_for_request(num_shards: usize, kv_block_size: usize) { async fn test_find_matches_for_request(num_shards: usize, kv_block_size: u32) {
setup(); setup();
let token = CancellationToken::new(); let token = CancellationToken::new();
let kv_indexer = make_indexer(&token, num_shards, kv_block_size); let kv_indexer = make_indexer(&token, num_shards, kv_block_size);
...@@ -1394,7 +1390,7 @@ mod tests { ...@@ -1394,7 +1390,7 @@ mod tests {
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_apply_event(num_shards: usize, kv_block_size: usize) { async fn test_apply_event(num_shards: usize, kv_block_size: u32) {
setup(); setup();
let worker_id = 0; let worker_id = 0;
...@@ -1409,7 +1405,7 @@ mod tests { ...@@ -1409,7 +1405,7 @@ mod tests {
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_shutdown(num_shards: usize, kv_block_size: usize) { async fn test_shutdown(num_shards: usize, kv_block_size: u32) {
setup(); setup();
let token = CancellationToken::new(); let token = CancellationToken::new();
let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size); let mut kv_indexer = make_indexer(&token, num_shards, kv_block_size);
...@@ -1419,7 +1415,7 @@ mod tests { ...@@ -1419,7 +1415,7 @@ mod tests {
#[tokio::test] #[tokio::test]
#[apply(indexer_template)] #[apply(indexer_template)]
async fn test_frequency(num_shards: usize, kv_block_size: usize) { async fn test_frequency(num_shards: usize, kv_block_size: u32) {
const ONE_MILLIS: Duration = Duration::from_millis(1); const ONE_MILLIS: Duration = Duration::from_millis(1);
setup(); setup();
......
...@@ -62,7 +62,7 @@ impl KvEventSource { ...@@ -62,7 +62,7 @@ impl KvEventSource {
/// Start the event source from a [`KvEventSourceConfig`]. /// Start the event source from a [`KvEventSourceConfig`].
fn start( fn start(
component: Component, component: Component,
kv_block_size: usize, kv_block_size: u32,
source_config: KvEventSourceConfig, source_config: KvEventSourceConfig,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
...@@ -98,7 +98,7 @@ impl KvEventSource { ...@@ -98,7 +98,7 @@ impl KvEventSource {
/// A publisher of KV events. /// A publisher of KV events.
pub struct KvEventPublisher { pub struct KvEventPublisher {
/// The size of the KV block. /// The size of the KV block.
kv_block_size: usize, kv_block_size: u32,
/// The source of KV events. /// The source of KV events.
/// Can be `None` if all events provided through [`KvEventPublisher::publish`]. /// Can be `None` if all events provided through [`KvEventPublisher::publish`].
source: Option<KvEventSource>, source: Option<KvEventSource>,
...@@ -112,7 +112,7 @@ impl KvEventPublisher { ...@@ -112,7 +112,7 @@ impl KvEventPublisher {
pub fn new( pub fn new(
component: Component, component: Component,
worker_id: i64, worker_id: i64,
kv_block_size: usize, kv_block_size: u32,
source_config: Option<KvEventSourceConfig>, source_config: Option<KvEventSourceConfig>,
) -> Result<Self> { ) -> Result<Self> {
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
...@@ -155,7 +155,7 @@ impl KvEventPublisher { ...@@ -155,7 +155,7 @@ impl KvEventPublisher {
self.tx.send(event) self.tx.send(event)
} }
pub fn kv_block_size(&self) -> usize { pub fn kv_block_size(&self) -> u32 {
self.kv_block_size self.kv_block_size
} }
...@@ -223,7 +223,7 @@ pub async fn start_zmq_listener( ...@@ -223,7 +223,7 @@ pub async fn start_zmq_listener(
zmq_topic: String, zmq_topic: String,
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
kv_block_size: usize, kv_block_size: u32,
) { ) {
tracing::debug!( tracing::debug!(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')", "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
...@@ -335,7 +335,7 @@ pub async fn start_zmq_listener( ...@@ -335,7 +335,7 @@ pub async fn start_zmq_listener(
fn convert_event( fn convert_event(
raw: RawKvEvent, raw: RawKvEvent,
event_id: u64, event_id: u64,
kv_block_size: usize, kv_block_size: u32,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent { ) -> KvCacheEvent {
match raw { match raw {
...@@ -382,7 +382,7 @@ fn convert_event( ...@@ -382,7 +382,7 @@ fn convert_event(
} }
pub fn create_stored_block_from_parts( pub fn create_stored_block_from_parts(
kv_block_size: usize, kv_block_size: u32,
block_hash: i64, block_hash: i64,
token_ids: &[u32], token_ids: &[u32],
_lora_id: u64, _lora_id: u64,
...@@ -395,7 +395,7 @@ pub fn create_stored_block_from_parts( ...@@ -395,7 +395,7 @@ pub fn create_stored_block_from_parts(
} }
pub fn create_stored_blocks( pub fn create_stored_blocks(
kv_block_size: usize, kv_block_size: u32,
token_ids: &[u32], token_ids: &[u32],
num_block_tokens: &[u64], num_block_tokens: &[u64],
block_hashes: &[i64], block_hashes: &[i64],
......
...@@ -92,7 +92,7 @@ pub struct KvScheduler { ...@@ -92,7 +92,7 @@ pub struct KvScheduler {
impl KvScheduler { impl KvScheduler {
pub async fn start( pub async fn start(
ns: Namespace, ns: Namespace,
block_size: usize, block_size: u32,
endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>, endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
...@@ -299,7 +299,7 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -299,7 +299,7 @@ impl WorkerSelector for DefaultWorkerSelector {
&self, &self,
workers: &ProcessedEndpoints, workers: &ProcessedEndpoints,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: usize, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0); assert!(request.isl_tokens > 0);
...@@ -307,7 +307,7 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -307,7 +307,7 @@ impl WorkerSelector for DefaultWorkerSelector {
return Err(KvSchedulerError::NoEndpoints); return Err(KvSchedulerError::NoEndpoints);
} }
let request_blocks = request.isl_tokens.div_ceil(block_size); let request_blocks = request.isl_tokens.div_ceil(block_size as usize);
let mut worker_logits = HashMap::new(); let mut worker_logits = HashMap::new();
// Calculate logits for each worker // Calculate logits for each worker
......
...@@ -15,6 +15,7 @@ pub mod common; ...@@ -15,6 +15,7 @@ pub mod common;
pub mod disagg_router; pub mod disagg_router;
pub mod discovery; pub mod discovery;
pub mod engines; pub mod engines;
pub mod entrypoint;
pub mod gguf; pub mod gguf;
pub mod http; pub mod http;
pub mod hub; pub mod hub;
......
...@@ -5,6 +5,9 @@ use std::fs; ...@@ -5,6 +5,9 @@ use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use anyhow::Context as _;
use dynamo_runtime::protocols::Endpoint as EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Component, Endpoint}, component::{Component, Endpoint},
...@@ -12,8 +15,10 @@ use dynamo_runtime::{ ...@@ -12,8 +15,10 @@ use dynamo_runtime::{
}; };
use crate::discovery::ModelEntry; use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig;
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::ModelType; use crate::model_type::ModelType;
use crate::request_template::RequestTemplate;
mod network_name; mod network_name;
pub use network_name::ModelNetworkName; pub use network_name::ModelNetworkName;
...@@ -25,58 +30,85 @@ const HF_SCHEME: &str = "hf://"; ...@@ -25,58 +30,85 @@ const HF_SCHEME: &str = "hf://";
/// is invisible, for example in a text chat. /// is invisible, for example in a text chat.
const DEFAULT_NAME: &str = "dynamo"; const DEFAULT_NAME: &str = "dynamo";
#[derive(Debug, Clone)] /// Engines don't usually provide a default, so we do.
pub struct LocalModel { const DEFAULT_KV_CACHE_BLOCK_SIZE: u32 = 16;
full_path: PathBuf,
card: ModelDeploymentCard, /// We can't have it default to 0, so pick something
const DEFAULT_HTTP_PORT: u16 = 8080;
pub struct LocalModelBuilder {
model_path: Option<PathBuf>,
model_name: Option<String>,
model_config: Option<PathBuf>,
endpoint_id: Option<EndpointId>,
context_length: Option<u32>,
template_file: Option<PathBuf>,
router_config: Option<RouterConfig>,
kv_cache_block_size: u32,
http_port: u16,
} }
impl Default for LocalModel { impl Default for LocalModelBuilder {
fn default() -> Self { fn default() -> Self {
LocalModel { LocalModelBuilder {
full_path: PathBuf::new(), kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE,
card: ModelDeploymentCard::with_name_only(DEFAULT_NAME), http_port: DEFAULT_HTTP_PORT,
model_path: Default::default(),
model_name: Default::default(),
model_config: Default::default(),
endpoint_id: Default::default(),
context_length: Default::default(),
template_file: Default::default(),
router_config: Default::default(),
} }
} }
} }
impl LocalModel { impl LocalModelBuilder {
pub fn with_name_only(name: &str) -> Self { pub fn model_path(&mut self, model_path: Option<PathBuf>) -> &mut Self {
LocalModel { self.model_path = model_path;
card: ModelDeploymentCard::with_name_only(name), self
..Default::default()
} }
pub fn model_name(&mut self, model_name: Option<String>) -> &mut Self {
self.model_name = model_name;
self
} }
pub fn card(&self) -> &ModelDeploymentCard { pub fn model_config(&mut self, model_config: Option<PathBuf>) -> &mut Self {
&self.card self.model_config = model_config;
self
} }
pub fn path(&self) -> &Path { pub fn endpoint_id(&mut self, endpoint_id: EndpointId) -> &mut Self {
&self.full_path self.endpoint_id = Some(endpoint_id);
self
} }
pub fn display_name(&self) -> &str { pub fn context_length(&mut self, context_length: Option<u32>) -> &mut Self {
&self.card.display_name self.context_length = context_length;
self
} }
pub fn service_name(&self) -> &str { /// Passing None resets it to default
&self.card.service_name pub fn kv_cache_block_size(&mut self, kv_cache_block_size: Option<u32>) -> &mut Self {
self.kv_cache_block_size = kv_cache_block_size.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE);
self
} }
pub fn is_gguf(&self) -> bool { pub fn http_port(&mut self, port: u16) -> &mut Self {
// GGUF is the only file (not-folder) we accept, so we don't need to check the extension self.http_port = port;
// We will error when we come to parse it self
self.full_path.is_file()
} }
/// Override max number of tokens in context. We usually only do this to limit kv cache allocation. pub fn router_config(&mut self, router_config: RouterConfig) -> &mut Self {
pub fn set_context_length(&mut self, context_length: usize) { self.router_config = Some(router_config);
self.card.context_length = context_length; self
} }
pub fn set_kv_cache_block_size(&mut self, block_size: usize) { pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
self.card.kv_cache_block_size = block_size; self.template_file = template_file;
self
} }
/// Make an LLM ready for use: /// Make an LLM ready for use:
...@@ -88,28 +120,60 @@ impl LocalModel { ...@@ -88,28 +120,60 @@ impl LocalModel {
/// The model name will depend on what "model_path" is: /// The model name will depend on what "model_path" is:
/// - A folder: The last part of the folder name: "/data/llms/Qwen2.5-3B-Instruct" -> "Qwen2.5-3B-Instruct" /// - A folder: The last part of the folder name: "/data/llms/Qwen2.5-3B-Instruct" -> "Qwen2.5-3B-Instruct"
/// - A file: The GGUF filename: "/data/llms/Qwen2.5-3B-Instruct-Q6_K.gguf" -> "Qwen2.5-3B-Instruct-Q6_K.gguf" /// - A file: The GGUF filename: "/data/llms/Qwen2.5-3B-Instruct-Q6_K.gguf" -> "Qwen2.5-3B-Instruct-Q6_K.gguf"
/// - An HF repo: The HF repo name: "Qwen/Qwen2.5-3B-Instruct" stays the same /// - An HF repo: The HF repo name: "Qwen/Qwen3-0.6B" stays the same
pub async fn prepare( pub async fn build(&mut self) -> anyhow::Result<LocalModel> {
model_path: &str, // Generate an endpoint ID for this model if the user didn't provide one.
override_config: Option<&Path>, // The user only provides one if exposing the model.
override_name: Option<String>, let endpoint_id = self
) -> anyhow::Result<LocalModel> { .endpoint_id
// Name it .take()
.unwrap_or_else(|| internal_endpoint("local_model"));
let template = self
.template_file
.as_deref()
.map(RequestTemplate::load)
.transpose()?;
// echo_full engine doesn't need a path. It's an edge case, move it out of the way.
if self.model_path.is_none() {
return Ok(LocalModel {
card: ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
),
full_path: PathBuf::new(),
endpoint_id,
template,
http_port: self.http_port,
// We always have one. The Option is so we can take it.
router_config: self
.router_config
.take()
.expect("unreachable, RouterConfig missing"),
});
}
// Main logic. We are running a model.
let model_path = self.model_path.take().unwrap();
let model_path = model_path.to_str().context("Invalid UTF-8 in model path")?;
// Check for hf:// prefix first, in case we really want an HF repo but it conflicts // Check for hf:// prefix first, in case we really want an HF repo but it conflicts
// with a relative path. // with a relative path.
let is_hf_repo = let is_hf_repo =
model_path.starts_with(HF_SCHEME) || !fs::exists(model_path).unwrap_or(false); model_path.starts_with(HF_SCHEME) || !fs::exists(model_path).unwrap_or(false);
let relative_path = model_path.trim_start_matches(HF_SCHEME); let relative_path = model_path.trim_start_matches(HF_SCHEME);
let full_path = if is_hf_repo { let full_path = if is_hf_repo {
// HF download if necessary // HF download if necessary
super::hub::from_hf(relative_path).await? super::hub::from_hf(relative_path).await?
} else { } else {
fs::canonicalize(relative_path)? fs::canonicalize(relative_path)?
}; };
// --model-config takes precedence over --model-path
let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
let mut card = ModelDeploymentCard::load(&model_config_path).await?;
let model_name = override_name.unwrap_or_else(|| { // Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| {
if is_hf_repo { if is_hf_repo {
// HF repos use their full name ("org/name") not the folder name // HF repos use their full name ("org/name") not the folder name
relative_path.to_string() relative_path.to_string()
...@@ -124,15 +188,83 @@ impl LocalModel { ...@@ -124,15 +188,83 @@ impl LocalModel {
}) })
} }
}); });
card.set_name(&model_name);
// Load the ModelDeploymentCard card.kv_cache_block_size = self.kv_cache_block_size;
// --model-config takes precedence over --model-path // Override max number of tokens in context. We usually only do this to limit kv cache allocation.
let model_config_path = override_config.unwrap_or(&full_path); if let Some(context_length) = self.context_length {
let mut card = ModelDeploymentCard::load(&model_config_path).await?; card.context_length = context_length;
card.set_name(&model_name); }
Ok(LocalModel { full_path, card }) Ok(LocalModel {
card,
full_path,
endpoint_id,
template,
http_port: self.http_port,
router_config: self
.router_config
.take()
.expect("unreachable, RouterConfig missing"),
})
}
}
#[derive(Debug, Clone)]
pub struct LocalModel {
full_path: PathBuf,
card: ModelDeploymentCard,
endpoint_id: EndpointId,
template: Option<RequestTemplate>,
http_port: u16, // Only used if input is HTTP server
router_config: RouterConfig,
}
impl LocalModel {
pub fn card(&self) -> &ModelDeploymentCard {
&self.card
}
pub fn path(&self) -> &Path {
&self.full_path
}
pub fn display_name(&self) -> &str {
&self.card.display_name
}
pub fn service_name(&self) -> &str {
&self.card.service_name
}
pub fn request_template(&self) -> Option<RequestTemplate> {
self.template.clone()
}
pub fn http_port(&self) -> u16 {
self.http_port
}
pub fn router_config(&self) -> &RouterConfig {
&self.router_config
}
pub fn is_gguf(&self) -> bool {
// GGUF is the only file (not-folder) we accept, so we don't need to check the extension
// We will error when we come to parse it
self.full_path.is_file()
}
/// An endpoint to identify this model by.
pub fn endpoint_id(&self) -> &EndpointId {
&self.endpoint_id
}
/// Drop the LocalModel returning it's ModelDeploymentCard.
/// For the case where we only need the card and don't want to clone it.
pub fn into_card(self) -> ModelDeploymentCard {
self.card
} }
/// Attach this model the endpoint. This registers it on the network /// Attach this model the endpoint. This registers it on the network
...@@ -202,3 +334,13 @@ impl LocalModel { ...@@ -202,3 +334,13 @@ impl LocalModel {
Ok(()) Ok(())
} }
} }
/// A random endpoint to use for internal communication
/// We can't hard code because we may be running several on the same machine (GPUs 0-3 and 4-7)
fn internal_endpoint(engine: &str) -> EndpointId {
EndpointId {
namespace: Slug::slugify(&uuid::Uuid::new_v4().to_string()).to_string(),
component: engine.to_string(),
name: "generate".to_string(),
}
}
...@@ -57,7 +57,7 @@ pub struct KvManager { ...@@ -57,7 +57,7 @@ pub struct KvManager {
max_capacity: usize, max_capacity: usize,
#[getter(copy)] #[getter(copy)]
block_size: usize, block_size: u32,
active_blocks: HashMap<UniqueBlock, usize>, active_blocks: HashMap<UniqueBlock, usize>,
...@@ -67,7 +67,7 @@ pub struct KvManager { ...@@ -67,7 +67,7 @@ pub struct KvManager {
} }
impl KvManager { impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self { pub fn new(max_capacity: usize, block_size: u32) -> Self {
let active_blocks = HashMap::new(); let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default(); let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new(); let all_blocks = HashSet::new();
...@@ -245,7 +245,7 @@ impl KvManager { ...@@ -245,7 +245,7 @@ impl KvManager {
let overlap_blocks = unique_blocks.len() - new_blocks; let overlap_blocks = unique_blocks.len() - new_blocks;
// Calculate new tokens // Calculate new tokens
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size; let new_tokens = sequence.num_input_tokens() - overlap_blocks * (self.block_size as usize);
// // Print the full equation with actual values substituted // // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)", // println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
...@@ -261,7 +261,7 @@ impl KvManager { ...@@ -261,7 +261,7 @@ impl KvManager {
// Calculate prefill compute // Calculate prefill compute
let prefill_compute = let prefill_compute =
new_tokens as f64 * (new_tokens + overlap_blocks * self.block_size) as f64; new_tokens as f64 * (new_tokens + overlap_blocks * (self.block_size as usize)) as f64;
Some(PrefillCost { Some(PrefillCost {
new_tokens, new_tokens,
......
...@@ -193,7 +193,7 @@ impl Scheduler { ...@@ -193,7 +193,7 @@ impl Scheduler {
pub fn new( pub fn new(
kv_capacity: usize, kv_capacity: usize,
watermark: f64, watermark: f64,
block_size: usize, block_size: u32,
chunk_size: Option<usize>, chunk_size: Option<usize>,
output_tx: Option<mpsc::Sender<Uuid>>, output_tx: Option<mpsc::Sender<Uuid>>,
cancellation_token: Option<CancellationToken>, cancellation_token: Option<CancellationToken>,
...@@ -272,7 +272,7 @@ impl Scheduler { ...@@ -272,7 +272,7 @@ impl Scheduler {
let mut kv_manager_guard = kv_manager_clone.lock().await; let mut kv_manager_guard = kv_manager_clone.lock().await;
// Base time needed for decoding (assumed memory bound on KV cache) // Base time needed for decoding (assumed memory bound on KV cache)
let active_tokens = kv_manager_guard.num_active_blocks() * block_size; let active_tokens = kv_manager_guard.num_active_blocks() * (block_size as usize);
// TODO: 2 is a dummy / magic scaling factor // TODO: 2 is a dummy / magic scaling factor
let mut generation_time = Duration::from_micros((active_tokens / 2) as u64); let mut generation_time = Duration::from_micros((active_tokens / 2) as u64);
...@@ -406,7 +406,7 @@ impl Scheduler { ...@@ -406,7 +406,7 @@ impl Scheduler {
} }
/// Convert a Request to an ActiveSequence /// Convert a Request to an ActiveSequence
fn get_active_sequence(request: Request, block_size: usize, chunk_size: usize) -> ActiveSequence { fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) -> ActiveSequence {
if let Request::Active(active_seq) = request { if let Request::Active(active_seq) = request {
return active_seq; return active_seq;
} }
...@@ -475,7 +475,7 @@ mod tests { ...@@ -475,7 +475,7 @@ mod tests {
let kv_capacity: usize = 500; let kv_capacity: usize = 500;
let watermark: f64 = 0.01; // 1% watermark let watermark: f64 = 0.01; // 1% watermark
let block_size: usize = 64; let block_size: u32 = 64;
let chunk_size: usize = 256; let chunk_size: usize = 256;
let num_requests: usize = 100; let num_requests: usize = 100;
let input_len: usize = 1000; let input_len: usize = 1000;
......
...@@ -23,7 +23,7 @@ use uuid; ...@@ -23,7 +23,7 @@ use uuid;
fn create_unique_blocks_from_sequence( fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence, tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>, uuid: Option<uuid::Uuid>,
block_size: usize, block_size: u32,
) -> Vec<UniqueBlock> { ) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks() .blocks()
...@@ -32,7 +32,7 @@ fn create_unique_blocks_from_sequence( ...@@ -32,7 +32,7 @@ fn create_unique_blocks_from_sequence(
.collect(); .collect();
// Only push the partial block if tokens count isn't a multiple of block_size // Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % block_size != 0 { if tokens.total_tokens() % (block_size as usize) != 0 {
unique_blocks.push(match uuid { unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid), Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(), None => UniqueBlock::default(),
...@@ -50,7 +50,7 @@ pub struct ActiveSequence { ...@@ -50,7 +50,7 @@ pub struct ActiveSequence {
tokens: TokenBlockSequence, tokens: TokenBlockSequence,
#[getter(copy)] #[getter(copy)]
block_size: usize, block_size: u32,
#[getter(copy)] #[getter(copy)]
chunk_size: usize, // TODO: not actually used chunk_size: usize, // TODO: not actually used
...@@ -72,7 +72,7 @@ impl ActiveSequence { ...@@ -72,7 +72,7 @@ impl ActiveSequence {
pub fn new( pub fn new(
tokens: Vec<u32>, tokens: Vec<u32>,
max_output_tokens: usize, max_output_tokens: usize,
block_size: Option<usize>, block_size: Option<u32>,
chunk_size: Option<usize>, chunk_size: Option<usize>,
) -> Self { ) -> Self {
let block_size = block_size.unwrap_or(64); let block_size = block_size.unwrap_or(64);
...@@ -96,8 +96,8 @@ impl ActiveSequence { ...@@ -96,8 +96,8 @@ impl ActiveSequence {
} }
} }
pub fn extra_tokens(&self) -> usize { pub fn extra_tokens(&self) -> u32 {
self.len() % self.block_size (self.len() % self.block_size as usize) as u32
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
...@@ -112,7 +112,7 @@ impl ActiveSequence { ...@@ -112,7 +112,7 @@ impl ActiveSequence {
pub fn new_with_signal( pub fn new_with_signal(
tokens: Vec<u32>, tokens: Vec<u32>,
max_output_tokens: usize, max_output_tokens: usize,
block_size: Option<usize>, block_size: Option<u32>,
chunk_size: Option<usize>, chunk_size: Option<usize>,
) -> (Self, Option<MoveBlock>) { ) -> (Self, Option<MoveBlock>) {
let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size); let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size);
...@@ -125,7 +125,7 @@ impl ActiveSequence { ...@@ -125,7 +125,7 @@ impl ActiveSequence {
self.tokens.append(token).expect("Token push failed."); self.tokens.append(token).expect("Token push failed.");
self.generated_tokens += 1; self.generated_tokens += 1;
if self.len() % self.block_size != 1 { if self.len() % (self.block_size as usize) != 1 {
return None; return None;
} }
...@@ -223,7 +223,7 @@ impl ActiveSequence { ...@@ -223,7 +223,7 @@ impl ActiveSequence {
self.generated_tokens = self.generated_tokens.saturating_sub(1); self.generated_tokens = self.generated_tokens.saturating_sub(1);
// Reverts to the last full block // Reverts to the last full block
if self.tokens.total_tokens() % self.block_size == 0 { if self.tokens.total_tokens() % (self.block_size as usize) == 0 {
self.unique_blocks.pop(); self.unique_blocks.pop();
} }
} }
...@@ -285,7 +285,7 @@ mod tests { ...@@ -285,7 +285,7 @@ mod tests {
// Verify state after pushing tokens // Verify state after pushing tokens
assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
assert_eq!(seq1.len(), 17); assert_eq!(seq1.len(), 17);
assert_eq!(seq1.len() % seq1.block_size(), 1); assert_eq!(seq1.len() % (seq1.block_size() as usize), 1);
// Create another sequence with block size 16 initialized with tokens [0..17] // Create another sequence with block size 16 initialized with tokens [0..17]
let extended_tokens: Vec<u32> = (0..16).collect(); let extended_tokens: Vec<u32> = (0..16).collect();
...@@ -335,12 +335,12 @@ mod tests { ...@@ -335,12 +335,12 @@ mod tests {
"seq2 should have exactly 3 blocks" "seq2 should have exactly 3 blocks"
); );
assert_eq!( assert_eq!(
seq1.len() % seq1.block_size(), seq1.len() % (seq1.block_size() as usize),
1, 1,
"seq1 should have 1 partial token" "seq1 should have 1 partial token"
); );
assert_eq!( assert_eq!(
seq2.len() % seq2.block_size(), seq2.len() % (seq2.block_size() as usize),
1, 1,
"seq2 should have 1 partial token" "seq2 should have 1 partial token"
); );
......
...@@ -76,7 +76,7 @@ impl ModelDeploymentCard { ...@@ -76,7 +76,7 @@ impl ModelDeploymentCard {
let content = super::model::load_gguf(gguf_file)?; let content = super::model::load_gguf(gguf_file)?;
let context_length = content.get_metadata()[&format!("{}.context_length", content.arch())] let context_length = content.get_metadata()[&format!("{}.context_length", content.arch())]
.to_u32() .to_u32()
.unwrap_or(0) as usize; .unwrap_or(0);
tracing::debug!(context_length, "Loaded context length from GGUF"); tracing::debug!(context_length, "Loaded context length from GGUF");
Ok(Self { Ok(Self {
......
...@@ -117,11 +117,11 @@ pub struct ModelDeploymentCard { ...@@ -117,11 +117,11 @@ pub struct ModelDeploymentCard {
pub revision: u64, pub revision: u64,
/// Max context (in number of tokens) this model can handle /// Max context (in number of tokens) this model can handle
pub context_length: usize, pub context_length: u32,
/// Size of a KV cache block - vllm only currently /// Size of a KV cache block - vllm only currently
/// Passed to the engine and the KV router. /// Passed to the engine and the KV router.
pub kv_cache_block_size: usize, pub kv_cache_block_size: u32,
} }
impl ModelDeploymentCard { impl ModelDeploymentCard {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment