Unverified Commit 183f2b32 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(dynamo-run): Allow setting KV cache block size (#1175)

Example:
```
dynamo-run out=<engine> <model> --kv-cache-block-size 64
```

In a distributed system this goes on the worker node and is propagated to ingress via the model deployment card.

Previously hard coded to 16, which is now the default.

- Load context_length from model. Closes #1172
- Store context length and KV cache block size in Model Deployment Card #1170
parent 7860861f
......@@ -592,6 +592,11 @@ The `model_type` can be:
- ModelType.Chat. Your `generate` method receives a `request` and must return a response dict of type [OpenAI Chat Completion](https://platform.openai.com/docs/api-reference/chat). Your engine handles pre-processing.
- ModelType.Completion. Your `generate` method receives a `request` and must return a response dict of the older [Completions](https://platform.openai.com/docs/api-reference/completions). Your engine handles pre-processing.
`register_llm` can also take the following kwargs:
- `model_name`: The name to call the model. The model name in incoming HTTP requests must match this. Defaults to the Hugging Face repo name, the folder name, or the GGUF file name.
- `context_length`: Max model length in tokens. Defaults to the model's set max. Only set this if you need to reduce KV cache allocation to fit into VRAM.
- `kv_cache_block_size`: Size of a KV block for the engine, in tokens. Defaults to 16.
Here are some example engines:
- Backend:
......
......@@ -111,6 +111,10 @@ pub struct Flags {
#[arg(long)]
pub context_length: Option<usize>,
/// KV cache block size (vllm only)
#[arg(long)]
pub kv_cache_block_size: Option<usize>,
/// Additional engine-specific arguments from a JSON file.
/// Contains a mapping of parameter names to values.
#[arg(long)]
......
......@@ -25,6 +25,9 @@ const PYTHON_STR_SCHEME: &str = "pystr:";
/// Where we will attach the vllm/sglang subprocess. Invisible to users.
pub const INTERNAL_ENDPOINT: &str = "dyn://dynamo.internal.worker";
/// Default size of a KV cache block. Override with --kv-cache-block-size
const DEFAULT_KV_CACHE_BLOCK_SIZE: usize = 16;
pub enum EngineConfig {
/// Remote networked engines
Dynamic,
......@@ -84,7 +87,17 @@ pub async fn run(
}
}
};
local_model.context_length = flags.context_length;
// Only set if user provides. Usually loaded from tokenizer_config.json
if let Some(context_length) = flags.context_length {
local_model.set_context_length(context_length);
}
// Always set, there is no engine provided default
local_model.set_kv_cache_block_size(
flags
.kv_cache_block_size
.unwrap_or(DEFAULT_KV_CACHE_BLOCK_SIZE),
);
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
......@@ -101,7 +114,16 @@ pub async fn run(
// Create the engine matching `out`
let engine_config = match out_opt {
Output::Dynamic => EngineConfig::Dynamic,
Output::Dynamic => {
    // Sanity check: with a dynamic (remote) engine this process is the ingress.
    // Context length and KV cache block size are set on the worker node and reach
    // the ingress via the model deployment card, so setting them here is a mistake.
    // TODO probably make a general sanity check at start of method
    if flags.context_length.is_some() {
        // The flag is spelled `--context-length`; the message must match so users
        // can find the option they passed.
        anyhow::bail!("'--context-length' flag should only be used on the worker node, not on the ingress");
    }
    if flags.kv_cache_block_size.is_some() {
        anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
    }
    EngineConfig::Dynamic
}
Output::EchoFull => EngineConfig::StaticFull {
model: Box::new(local_model),
engine: dynamo_llm::engines::make_engine_full(),
......@@ -145,19 +167,12 @@ pub async fn run(
subprocess::sglang::PY,
&local_model,
&endpoint,
flags.tensor_parallel_size,
flags.context_length,
if flags.base_gpu_id == 0 {
None
} else {
Some(flags.base_gpu_id)
},
flags.clone(),
if flags.num_nodes <= 1 {
None
} else {
Some(multi_node_conf)
},
flags.extra_engine_args.as_deref(),
)
.await
{
......@@ -190,11 +205,8 @@ pub async fn run(
subprocess::vllm::PY,
&local_model,
&endpoint,
flags.tensor_parallel_size,
flags.context_length,
None, // base_gpu_id. vllm uses CUDA_VISIBLE_DEVICES instead
flags.clone(),
None, // multi-node config. vllm uses `ray`, see guide
flags.extra_engine_args.as_deref(),
)
.await
{
......
......@@ -30,7 +30,7 @@ Example:
- OR: ./dynamo-run /data/models/Llama-3.2-1B-Instruct-Q4_K_M.gguf
"#;
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv]";
const USAGE: &str = "USAGE: dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--kv-cache-block-size=16] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv]";
fn main() -> anyhow::Result<()> {
// Set log level based on verbosity flag
......
......@@ -3,7 +3,6 @@
use std::borrow::Cow;
use std::io::Write;
use std::path::Path;
use std::process::Stdio;
use std::sync::LazyLock;
......@@ -18,8 +17,6 @@ use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang;
pub mod vllm;
// TODO: I guess make a config object?
#[allow(clippy::too_many_arguments)]
pub async fn start(
// The Python code to run
py_script: &'static str,
......@@ -27,23 +24,17 @@ pub async fn start(
local_model: &LocalModel,
// Endpoint to connect the subprocess over etcd/nats
endpoint: &EndpointId,
// How many GPUs to use
tensor_parallel_size: u32,
// Max context length to allow
context_length: Option<usize>,
// sglang which GPU to start from, on a multi-GPU system
// vllm uses CUDA_VISIBLE_DEVICES
base_gpu_id: Option<u32>,
// Command line flags for user overrides
flags: super::Flags,
// sglang multi-node config. vllm uses `ray` externally
multi_node_config: Option<MultiNodeConfig>,
// Path to a JSON file containing extra arguments to the backend engine
extra_engine_args: Option<&Path>,
) -> anyhow::Result<(tempfile::TempPath, tokio::process::Child)> {
let mut tmp = tempfile::NamedTempFile::new()?;
// Writes on Linux don't block
tmp.write_all(py_script.as_bytes())?;
let script_path = tmp.into_temp_path();
let card = local_model.card();
let mut args = vec![
script_path.to_string_lossy().to_string(),
"--endpoint".to_string(),
......@@ -53,19 +44,17 @@ pub async fn start(
"--model-name".to_string(),
local_model.display_name().to_string(),
"--tensor-parallel-size".to_string(),
tensor_parallel_size.to_string(),
flags.tensor_parallel_size.to_string(),
"--kv-block-size".to_string(),
dynamo_llm::DEFAULT_KV_BLOCK_SIZE.to_string(),
card.kv_cache_block_size.to_string(),
"--context-length".to_string(),
card.context_length.to_string(),
];
if let Some(context_length) = context_length {
args.push("--context-length".to_string());
args.push(context_length.to_string());
}
// sglang only
if let Some(base_gpu_id) = base_gpu_id {
// vllm uses CUDA_VISIBLE_DEVICES
if flags.base_gpu_id != 0 {
args.push("--base-gpu-id".to_string());
args.push(base_gpu_id.to_string());
args.push(flags.base_gpu_id.to_string());
}
// sglang only
if let Some(multi_node_config) = multi_node_config {
......@@ -76,7 +65,7 @@ pub async fn start(
args.push("--dist-init-addr".to_string());
args.push(multi_node_config.leader_addr);
}
if let Some(extra_engine_args) = extra_engine_args {
if let Some(extra_engine_args) = flags.extra_engine_args {
args.push("--extra-engine-args".to_string());
args.push(extra_engine_args.to_string_lossy().to_string());
}
......
......@@ -91,8 +91,11 @@ async def init(runtime: DistributedRuntime, config: Config):
"skip_tokenizer_init": True,
"tp_size": config.tensor_parallel_size,
"base_gpu_id": config.base_gpu_id,
"page_size": config.kv_block_size,
}
if config.kv_block_size:
arg_map["page_size"] = config.kv_block_size
if config.context_length:
arg_map["context_length"] = config.context_length
......
......@@ -154,10 +154,12 @@ async def init(runtime: DistributedRuntime, config: Config):
"skip_tokenizer_init": True,
"disable_log_requests": True,
"enable_prefix_caching": True,
"block_size": config.kv_block_size,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
}
if config.kv_block_size:
arg_map["block_size"] = config.kv_block_size
if config.context_length:
# Usually we want it to default to the max (from tokenizer_config.json)
arg_map["max_model_len"] = config.context_length
......
......@@ -92,13 +92,15 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None))]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None))]
fn register_llm<'p>(
py: Python<'p>,
model_type: ModelType,
endpoint: Endpoint,
model_path: &str,
model_name: Option<&str>,
context_length: Option<usize>,
kv_cache_block_size: Option<usize>,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
......@@ -115,6 +117,12 @@ fn register_llm<'p>(
llm_rs::local_model::LocalModel::prepare(&inner_path, None, model_name)
.await
.map_err(to_pyerr)?;
if let Some(context_length) = context_length {
local_model.set_context_length(context_length);
}
if let Some(kv_cache_block_size) = kv_cache_block_size {
local_model.set_kv_cache_block_size(kv_cache_block_size);
}
// Advertise ourself on etcd so ingress can find us
local_model
......
......@@ -29,10 +29,6 @@ use dynamo_llm::{backend::ExecutionContext, local_model::LocalModel};
/// If the user does not provide max_tokens, limit prompt+output to this many tokens
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// If the user does not provide a context length limit default to this.
/// TODO: This should come from GGUF key {model}.context_length
const CONTEXT_LENGTH: u32 = 32768;
static LLAMA_BACKEND: tokio::sync::OnceCell<LlamaBackend> = tokio::sync::OnceCell::const_new();
pub(crate) static LLAMA_MODEL: tokio::sync::OnceCell<LlamaModel> =
tokio::sync::OnceCell::const_new();
......@@ -75,13 +71,13 @@ impl LlamacppEngine {
LLAMA_MODEL.set(model)?;
let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS);
let n_ctx = NonZeroU32::new(
model_config
.context_length
.map(|n| n as u32)
.unwrap_or(CONTEXT_LENGTH),
);
let llama_ctx_params = LlamaContextParams::default().with_n_ctx(n_ctx);
let llama_ctx_params = if model_config.card().context_length > 0 {
let n_ctx = NonZeroU32::new(model_config.card().context_length as u32);
LlamaContextParams::default().with_n_ctx(n_ctx)
} else {
// Context length defaults to 512 currently
LlamaContextParams::default()
};
for (i, ctx_holder) in LLAMA_CONTEXTS.iter().enumerate().take(NUM_CONTEXTS) {
let llama_ctx = LLAMA_MODEL
.get()
......
......@@ -127,10 +127,11 @@ impl MistralRsEngine {
.build(None)?
};
// TODO: The default max seq len should come from the model not be hard coded
let max_seq_len = model
.context_length
.unwrap_or(AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN);
let mut max_seq_len = model.card().context_length;
if max_seq_len == 0 {
tracing::info!("context_length is 0. Probably error reading from model.");
max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
}
// Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
......
......@@ -178,11 +178,13 @@ impl ModelManager {
&self,
model_name: &str,
component: &Component,
kv_cache_block_size: usize,
) -> anyhow::Result<Arc<KvRouter>> {
if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
return Ok(kv_chooser);
}
self.create_kv_chooser(model_name, component).await
self.create_kv_chooser(model_name, component, kv_cache_block_size)
.await
}
fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
......@@ -194,14 +196,10 @@ impl ModelManager {
&self,
model_name: &str,
component: &Component,
kv_cache_block_size: usize,
) -> anyhow::Result<Arc<KvRouter>> {
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
let chooser = KvRouter::new(component.clone(), kv_cache_block_size, Some(selector)).await?;
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
......
......@@ -208,7 +208,7 @@ impl ModelWatcher {
RouterMode::KV => {
let chooser = self
.manager
.kv_chooser_for(&model_entry.name, &component)
.kv_chooser_for(&model_entry.name, &component, card.kv_cache_block_size)
.await?;
let kv_push_router = KvPushRouter::new(router, chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router))
......@@ -243,7 +243,7 @@ impl ModelWatcher {
RouterMode::KV => {
let chooser = self
.manager
.kv_chooser_for(&model_entry.name, &component)
.kv_chooser_for(&model_entry.name, &component, card.kv_cache_block_size)
.await?;
let kv_push_router = KvPushRouter::new(router, chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router))
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
......@@ -50,9 +38,6 @@ use crate::{
use dynamo_runtime::traits::events::EventSubscriber;
// TODO: Allow user to change
pub const DEFAULT_KV_BLOCK_SIZE: usize = 16;
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
pub const KV_EVENT_SUBJECT: &str = "kv_events";
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Dynamo LLM
//!
......@@ -28,7 +16,6 @@ pub mod http;
pub mod hub;
pub mod key_value_store;
pub mod kv_router;
pub use kv_router::DEFAULT_KV_BLOCK_SIZE;
pub mod local_model;
pub mod mocker;
pub mod model_card;
......
......@@ -27,10 +27,6 @@ const DEFAULT_NAME: &str = "dynamo";
pub struct LocalModel {
full_path: PathBuf,
card: ModelDeploymentCard,
/// The max context the engine will allow us sending to the model.
/// If not set this defaults to the engine's configured maximum.
pub context_length: Option<usize>,
}
impl Default for LocalModel {
......@@ -38,7 +34,6 @@ impl Default for LocalModel {
LocalModel {
full_path: PathBuf::new(),
card: ModelDeploymentCard::with_name_only(DEFAULT_NAME),
context_length: None,
}
}
}
......@@ -67,6 +62,15 @@ impl LocalModel {
&self.card.service_name
}
/// Override the max number of tokens in context, stored on the model deployment card.
/// We usually only do this to limit KV cache allocation.
pub fn set_context_length(&mut self, context_length: usize) {
self.card.context_length = context_length;
}
/// Set the size of a KV cache block, stored on the model deployment card.
/// NOTE(review): presumably measured in tokens, as the README states for
/// `kv_cache_block_size` - confirm.
pub fn set_kv_cache_block_size(&mut self, block_size: usize) {
self.card.kv_cache_block_size = block_size;
}
/// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path
......@@ -120,11 +124,7 @@ impl LocalModel {
let mut card = ModelDeploymentCard::load(&model_config_path).await?;
card.set_name(&model_name);
Ok(LocalModel {
full_path,
card,
..Default::default()
})
Ok(LocalModel { full_path, card })
}
/// Attach this model to the endpoint. This registers it on the network
......
......@@ -17,8 +17,9 @@ use std::collections::HashMap;
use crate::model_card::model::ModelDeploymentCard;
use anyhow::{Context, Result};
use std::fs;
use std::path::Path;
use std::fs::{self, File};
use std::io::BufReader;
use std::path::{Path, PathBuf};
use crate::model_card::model::{ModelInfoType, PromptFormatterArtifact, TokenizerKind};
......@@ -84,6 +85,14 @@ impl ModelDeploymentCard {
gguf_file.display()
);
};
// TODO: we do this in HFConfig also, unify
let content = super::model::load_gguf(gguf_file)?;
// Use `get` instead of indexing the metadata map: indexing panics when the GGUF
// has no `{arch}.context_length` entry. A missing entry now yields 0, the same
// fallback already used when the value cannot be converted to u32.
let context_length_key = format!("{}.context_length", content.arch());
let context_length = content
    .get_metadata()
    .get(&context_length_key)
    .map(|value| value.to_u32().unwrap_or(0))
    .unwrap_or(0) as usize;
tracing::debug!(context_length, "Loaded context length from GGUF");
Ok(Self {
display_name: model_name.to_string(),
service_name: model_name.to_string(),
......@@ -93,11 +102,11 @@ impl ModelDeploymentCard {
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
context_length,
kv_cache_block_size: 0,
})
}
/// TODO: This will be implemented after nova-hub is integrated with the model-card
/// TODO: Attempt to auto-detect model type and construct an MDC from a NGC repo
#[allow(dead_code)]
async fn from_ngc_repo(_: &str) -> anyhow::Result<Self> {
Err(anyhow::anyhow!(
......@@ -106,6 +115,16 @@ impl ModelDeploymentCard {
}
async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> {
let context_length = file_json_field(
&Path::join(&PathBuf::from(repo_id), "tokenizer_config.json"),
"model_max_length",
)
.unwrap_or(0);
tracing::trace!(
context_length,
"Loaded context length (model_max_length) from tokenizer_config.json"
);
Ok(Self {
display_name: model_name.to_string(),
service_name: model_name.to_string(),
......@@ -115,6 +134,8 @@ impl ModelDeploymentCard {
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
context_length,
kv_cache_block_size: 0, // set later
})
}
}
......@@ -221,3 +242,58 @@ fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
}
Ok(())
}
/// Load a single top-level field from a JSON file and deserialize it as `T`.
///
/// The file is parsed into a generic `serde_json::Value` first, because only one
/// field is wanted and `T` describes that field, not the whole document.
///
/// # Arguments
///
/// * `json_file_path`: Path of the JSON file to read.
/// * `field_name`: Key to look up in the JSON root object.
///
/// # Errors
///
/// Returns an `anyhow::Error` when the file cannot be opened, its content is not
/// valid JSON, the JSON root is not an object, the field is absent, or the field's
/// value cannot be deserialized into `T`.
///
/// # Type Parameters
///
/// * `T`: Type of the field's value. Must implement `serde::de::DeserializeOwned`.
fn file_json_field<T: serde::de::DeserializeOwned>(
    json_file_path: &Path,
    field_name: &str,
) -> anyhow::Result<T> {
    // Open the file and wrap it in a buffered reader for the JSON parser.
    let reader = File::open(json_file_path)
        .map(BufReader::new)
        .with_context(|| format!("Failed to open file: {:?}", json_file_path))?;

    // Parse the whole document generically so the field can be looked up by name.
    let document: serde_json::Value = serde_json::from_reader(reader)
        .with_context(|| format!("Failed to parse JSON from file: {:?}", json_file_path))?;

    // Only a JSON object at the root has named fields.
    let Some(fields) = document.as_object() else {
        anyhow::bail!("JSON root is not an object in file: {:?}", json_file_path);
    };

    let Some(raw_value) = fields.get(field_name) else {
        anyhow::bail!(
            "Field '{}' not found in JSON file: {:?}",
            field_name,
            json_file_path
        );
    };

    // `from_value` consumes its input, so hand it a copy of the borrowed value.
    serde_json::from_value(raw_value.clone()).with_context(|| {
        format!(
            "Failed to deserialize field '{}' (value: {:?}) to the expected type from file: {:?}",
            field_name, raw_value, json_file_path
        )
    })
}
......@@ -123,6 +123,13 @@ pub struct ModelDeploymentCard {
/// Incrementing count of how many times we published this card
#[serde(default, skip_serializing)]
pub revision: u64,
/// Max context (in number of tokens) this model can handle
pub context_length: usize,
/// Size of a KV cache block - vllm only currently
/// Passed to the engine and the KV router.
pub kv_cache_block_size: usize,
}
impl ModelDeploymentCard {
......@@ -486,7 +493,7 @@ impl TokenizerKind {
}
}
fn load_gguf(gguf_file: &Path) -> anyhow::Result<Content> {
pub(crate) fn load_gguf(gguf_file: &Path) -> anyhow::Result<Content> {
let filename = gguf_file.display().to_string();
let mut f = File::open(gguf_file).with_context(|| filename.clone())?;
// vec because GGUF can be split into multiple files (shards)
......
......@@ -19,7 +19,7 @@
"single_word": false
},
"legacy": false,
"model_max_length": 1000000000000000019884624838656,
"model_max_length": 2048,
"pad_token": null,
"padding_side": "right",
"sp_model_kwargs": {},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment