Unverified Commit bbe82f18 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Remove dynamo-run and mistral-rs engine (#6203)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 2c747d64
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Context as _;
use dynamo_llm::entrypoint::EngineConfig;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_runtime::distributed::{DistributedConfig, RequestPlaneMode};
use dynamo_runtime::storage::kv;
use dynamo_runtime::transports::nats;
use dynamo_runtime::{DistributedRuntime, Runtime};
mod flags;
pub use flags::Flags;
mod opt;
pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::Output;
/// Runs one dynamo-run session: resolves the model, builds the `LocalModel` and the
/// `DistributedRuntime`, creates the engine selected by `out_opt` and drives it from `in_opt`.
///
/// - `runtime`: runtime handed to `DistributedRuntime::new`.
/// - `in_opt`: where requests come from. `Input::Endpoint` also sets the model's endpoint id.
/// - `out_opt`: engine to use. `None` falls back to `default_engine_for`.
/// - `flags`: parsed command line flags. Mutable because the TLS cert/key paths are taken out of it.
///
/// # Errors
/// Propagates any failure from fetching the model, parsing the endpoint / store / request
/// plane values, building the runtime or model, `Flags::validate`, `engine_for` and `run_input`.
pub async fn run(
runtime: Runtime,
in_opt: Input,
out_opt: Option<Output>,
mut flags: Flags,
) -> anyhow::Result<()> {
//
// Download
//
// The positional model path wins over the `--model-path` style flag.
let maybe_remote_repo = flags
.model_path_pos
.clone()
.or_else(|| flags.model_path_flag.clone());
// Preserve the original model identifier before downloading (for default model name)
let original_model_identifier = maybe_remote_repo.as_ref().map(|p| p.display().to_string());
let model_path = match maybe_remote_repo {
None => None,
Some(p) if p.exists() => {
// Already a local path
Some(p)
}
Some(p) => {
// model_path might be an HF repo, not a local path. Resolve it by downloading.
// Mocker only needs tokenizer, not weights
let ignore_weights = matches!(out_opt, Some(Output::Mocker));
Some(LocalModel::fetch(&p.display().to_string(), ignore_weights).await?)
}
};
//
// Configure
//
let mut builder = LocalModelBuilder::default();
builder
// An explicit model name wins; otherwise use the identifier the user typed.
.model_name(flags.model_name.clone().or(original_model_identifier))
.kv_cache_block_size(flags.kv_cache_block_size)
// Only set if user provides. Usually loaded from tokenizer_config.json
.context_length(flags.context_length)
.http_port(flags.http_port)
.tls_cert_path(flags.tls_cert_path.take())
.tls_key_path(flags.tls_key_path.take())
.router_config(Some(flags.router_config()))
.migration_limit(flags.migration_limit)
.request_template(flags.request_template.clone())
.is_mocker(matches!(out_opt, Some(Output::Mocker)));
// Only the worker has a model path
if let Some(model_path) = model_path {
builder.model_path(model_path);
}
// TODO: old, address this later:
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
if let Input::Endpoint(path) = &in_opt {
builder.endpoint_id(Some(path.parse().with_context(|| path.clone())?));
}
// Note this uses the user-supplied `out_opt`, before the default engine is filled in below.
let dst_config = if is_process_local(&in_opt, &out_opt) {
// We are both the frontend and backend, no networking
DistributedConfig::process_local()
} else {
// Normal case
let selected_store: kv::Selector = flags.store_kv.parse()?;
let request_plane: RequestPlaneMode = flags.request_plane.parse()?;
DistributedConfig {
store_backend: selected_store,
// We only need NATS here to monitor its metrics, so only if it's our request plane.
nats_config: if request_plane.is_nats() {
Some(nats::ClientOptions::default())
} else {
None
},
request_plane,
}
};
let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
let local_model = builder.build().await?;
//
// Create an engine
//
let out_opt = out_opt.unwrap_or_else(|| default_engine_for(&local_model));
print_cuda(&out_opt);
// Now that we know the output we're targeting, check if we expect it to work
flags.validate(&out_opt)?;
// Make an engine from the local_model, flags and output.
let engine_config = engine_for(
out_opt,
flags.clone(),
local_model,
distributed_runtime.clone(),
)
.await?;
// Run it from an input
dynamo_llm::entrypoint::input::run_input(distributed_runtime, in_opt, engine_config).await?;
Ok(())
}
/// Reports whether requests arrive over a network endpoint (`Input::Endpoint`).
pub fn is_in_dynamic(in_opt: &Input) -> bool {
    let Input::Endpoint(_) = in_opt else {
        return false;
    };
    true
}
/// Reports whether the output was explicitly set to `Output::Auto`.
/// An unset output (`None`) is not considered dynamic.
pub fn is_out_dynamic(out_opt: &Option<Output>) -> bool {
    out_opt
        .as_ref()
        .is_some_and(|engine| matches!(engine, Output::Auto))
}
/// True when neither side is dynamic, i.e. input and engine both live in this process.
fn is_process_local(in_opt: &Input, out_opt: &Option<Output>) -> bool {
    let uses_network = is_in_dynamic(in_opt) || is_out_dynamic(out_opt);
    !uses_network
}
/// Builds the `EngineConfig` for the engine chosen by `out_opt`.
///
/// Validation is the job of `Flags::validate`; this function assumes the combination of
/// flags and output is workable and only reports errors raised while creating an engine.
async fn engine_for(
    out_opt: Output,
    flags: Flags,
    local_model: LocalModel,
    drt: DistributedRuntime,
) -> anyhow::Result<EngineConfig> {
    let config = match out_opt {
        // Backends are discovered at runtime, so no engine is created here.
        Output::Auto => EngineConfig::Dynamic {
            model: Box::new(local_model),
            chat_engine_factory: None,
        },
        Output::Echo => EngineConfig::InProcessText {
            model: Box::new(local_model),
            engine: dynamo_llm::engines::make_echo_engine(),
        },
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => {
            // The engine borrows the model, so create it before the model is boxed.
            let engine = dynamo_engine_mistralrs::make_engine(&local_model).await?;
            EngineConfig::InProcessText {
                engine,
                model: Box::new(local_model),
            }
        }
        Output::Mocker => {
            let mocker_args = flags.mocker_config();
            let endpoint_id = local_model.endpoint_id().clone();
            let engine =
                dynamo_llm::mocker::make_mocker_engine(drt, endpoint_id, mocker_args).await?;
            EngineConfig::InProcessTokens {
                engine,
                model: Box::new(local_model),
                is_prefill: false,
            }
        }
    };
    Ok(config)
}
/// Logs which accelerator support was compiled in, or a hint to rebuild with one.
/// Nothing is logged unless the chosen engine is mistralrs.
// Only mistralrs needs to be built with CUDA.
// The Python engines only need it at runtime.
#[cfg(feature = "mistralrs")]
fn print_cuda(output: &Output) {
    // The engine may be compiled in without being the one selected.
    if !matches!(output, Output::MistralRs) {
        return;
    }
    #[cfg(feature = "cuda")]
    tracing::info!("CUDA on");
    #[cfg(feature = "metal")]
    tracing::info!("Metal on");
    #[cfg(not(any(feature = "cuda", feature = "metal")))]
    tracing::info!("CPU mode. Rebuild with `--features cuda|metal` for better performance");
}

/// Without the mistralrs feature there is nothing to report.
#[cfg(not(feature = "mistralrs"))]
fn print_cuda(_output: &Output) {}
/// Picks the engine used when the user gave no `out=` option.
/// The model is currently ignored; the choice is delegated to `safetensors_default`.
fn default_engine_for(_local_model: &LocalModel) -> Output {
safetensors_default()
}
/// Default engine: mistralrs when that feature is compiled in, otherwise the echo engine.
fn safetensors_default() -> Output {
    #[cfg(feature = "mistralrs")]
    let engine = Output::MistralRs;
    #[cfg(not(feature = "mistralrs"))]
    let engine = Output::Echo;
    engine
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::env;
use clap::{CommandFactory as _, Parser};
use dynamo_runtime::config::environment_names::logging as env_logging;
use dynamo_llm::entrypoint::input::Input;
use dynamo_run::Output;
use dynamo_runtime::logging;
/// Long description printed after the usage line when help is requested.
const HELP: &str = r#"
dynamo-run is a single binary that wires together the various inputs (http, text, network) and workers (network, engine), that runs the services. It is the simplest way to use dynamo locally.
Verbosity:
- -v enables debug logs
- -vv enables full trace logs
- Default is info level logging
Example:
- cargo build --features cuda -p dynamo-run
- cd target/debug
- ./dynamo-run Qwen/Qwen3-0.6B (OR ./dynamo-run /data/hf-checkouts/Qwen3-0.6B)
See `docs/guides/dynamo_run.md` in the repo for full details.
"#;
/// One-line usage summary. The `ENGINE_LIST` placeholder is replaced with
/// `Output::available_engines()` joined by `|` before printing.
const USAGE: &str = "USAGE: dynamo-run in=[http|grpc|text|dyn://<path>|batch:<folder>] out=ENGINE_LIST|auto|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--context-length=N] [--kv-cache-block-size=16] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--router-temperature=0.0] [--use-kv-events] [--max-num-batched-tokens=1.0] [--migration-limit=0] [--verbosity (-v|-vv)]";
fn main() -> anyhow::Result<()> {
// Set log level based on verbosity flag
let log_level = match dynamo_run::Flags::try_parse() {
Ok(flags) => match flags.verbosity {
0 => "info",
1 => "debug",
2 => "trace",
_ => {
return Err(anyhow::anyhow!(
"Invalid verbosity level. Valid values are v (debug) or vv (trace)"
));
}
},
Err(_) => "info",
};
if log_level != "info" {
unsafe { std::env::set_var(env_logging::DYN_LOG, log_level) };
}
logging::init();
// max_worker_threads and max_blocking_threads from env vars or config file.
let rt_config = dynamo_runtime::RuntimeConfig::from_settings()?;
tracing::debug!("Runtime config: {rt_config}");
// One per process. Wraps a Runtime with holds one or two tokio runtimes.
let worker = dynamo_runtime::Worker::from_config(rt_config)?;
worker.execute(wrapper)
}
async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
let mut in_opt = None;
let mut out_opt = None;
let args: Vec<String> = env::args().skip(1).collect();
if args.is_empty()
|| args[0] == "-h"
|| args[0] == "--help"
|| (args.iter().all(|arg| arg == "-v" || arg == "-vv"))
{
let engine_list = Output::available_engines().join("|");
let usage = USAGE.replace("ENGINE_LIST", &engine_list);
println!("{usage}");
println!("{HELP}");
dynamo_run::Flags::command().print_long_help().unwrap();
return Ok(());
} else if args[0] == "--version" {
if let Some(describe) = option_env!("VERGEN_GIT_DESCRIBE") {
println!("dynamo-run {}", describe);
} else {
println!("Version not available (git describe not available)");
}
return Ok(());
}
for arg in env::args().skip(1).take(2) {
let Some((in_out, val)) = arg.split_once('=') else {
// Probably we're defaulting in and/or out, and this is a flag
continue;
};
match in_out {
"in" => {
in_opt = Some(val.try_into()?);
}
"out" => {
if val == "sglang" || val == "trtllm" || val == "vllm" {
tracing::error!(
"To run the {val} engine please use the Python interface, see root README or look in directory `examples/backends/`."
);
std::process::exit(1);
}
out_opt = Some(val.try_into()?);
}
_ => {
anyhow::bail!("Invalid argument, must start with 'in' or 'out. {USAGE}");
}
}
}
let mut non_flag_params = 1; // binary name
let in_opt = match in_opt {
Some(x) => {
non_flag_params += 1;
x
}
None => Input::default(),
};
if out_opt.is_some() {
non_flag_params += 1;
}
// Clap skips the first argument expecting it to be the binary name, so add it back
// Note `--model-path` has index=1 (in lib.rs) so that doesn't need a flag.
let flags = dynamo_run::Flags::try_parse_from(
["dynamo-run".to_string()]
.into_iter()
.chain(env::args().skip(non_flag_params)),
)?;
if dynamo_run::is_in_dynamic(&in_opt) && dynamo_run::is_out_dynamic(&out_opt) {
anyhow::bail!("Cannot use endpoint for both in and out");
}
dynamo_run::run(runtime, in_opt, out_opt, flags).await
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::fmt;
/// The engine ("out=") a request is sent to.
pub enum Output {
/// Echos the prompt back as the response
Echo,
/// Listen for models on nats/etcd, add/remove dynamically
Auto,
/// In-process mistral.rs engine. Only exists when built with the `mistralrs` feature.
#[cfg(feature = "mistralrs")]
MistralRs,
/// Mocker engine, selected with `out=mocker`.
Mocker,
}
/// Parses the value of an `out=` argument.
impl TryFrom<&str> for Output {
    type Error = anyhow::Error;

    /// Accepts `mocker`, `echo` / `echo_full`, `dyn` / `auto`, and `mistralrs` when that
    /// feature is compiled in. Anything else is an error naming the rejected value.
    fn try_from(s: &str) -> anyhow::Result<Self> {
        let engine = match s {
            #[cfg(feature = "mistralrs")]
            "mistralrs" => Output::MistralRs,
            "mocker" => Output::Mocker,
            "echo" | "echo_full" => Output::Echo,
            "dyn" | "auto" => Output::Auto,
            e => anyhow::bail!("Invalid out= option '{e}'"),
        };
        Ok(engine)
    }
}
/// Renders the engine using the same lowercase name accepted on the command line.
impl fmt::Display for Output {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            #[cfg(feature = "mistralrs")]
            Output::MistralRs => "mistralrs",
            Output::Mocker => "mocker",
            Output::Echo => "echo",
            Output::Auto => "auto",
        })
    }
}
impl Output {
    /// Names of the engines compiled into this binary, in the order shown in the usage text:
    /// `echo`, the mocker, and mistralrs when that feature is enabled.
    pub fn available_engines() -> Vec<String> {
        let mut engines = Vec::with_capacity(3);
        engines.push(String::from("echo"));
        engines.push(Output::Mocker.to_string());
        #[cfg(feature = "mistralrs")]
        engines.push(Output::MistralRs.to_string());
        engines
    }
}
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Example cli using the Python bindings, similar to `dynamo-run`.
# Example cli using the Python bindings.
#
# Usage: `python cli.py in=text out=echo <your-model>`.
# `in` can be:
......
......@@ -16,8 +16,7 @@
# Start nats and etcd:
# - nats-server -js
#
# Window 1: `python server_sglang.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn
# `python server_sglang.py`. Wait for log "Starting endpoint".
import argparse
import asyncio
......
......@@ -17,8 +17,7 @@
# Start nats and etcd:
# - nats-server -js
#
# Window 1: `python server_sglang.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn
# `python server_sglang.py`. Wait for log "Starting endpoint".
import argparse
import asyncio
......
......@@ -12,8 +12,7 @@
# Start nats and etcd:
# - nats-server -js
#
# Window 1: `python server_vllm.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn
# `python server_vllm.py`. Wait for log "Starting endpoint".
import argparse
import asyncio
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
[package]
name = "dynamo-engine-mistralrs"
version.workspace = true
edition.workspace = true
description.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
keywords.workspace = true
[features]
default = []
cuda = ["mistralrs/cuda"]
metal = ["mistralrs/metal"]
[dependencies]
dynamo-runtime = { workspace = true }
dynamo-llm = { workspace = true }
anyhow = { workspace = true }
dynamo-async-openai = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
either = { workspace = true }
indexmap = { version = "2.9.0", features = ["serde"] }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", version = "0.6.0", rev = "2bcf0e9e3" }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::{num::NonZero, sync::Arc};
use async_stream::stream;
use async_trait::async_trait;
use dynamo_async_openai::types::FinishReason;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
PagedCacheType, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig,
StopTokens, TokenSource, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
};
use tokio::sync::mpsc::channel;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_llm::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse, prompt_to_string},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
};
use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
use dynamo_llm::local_model::LocalModel;
/// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
/// I would call this the batch size but apparently that's something else.
/// Only used when the paged attention scheduler is selected.
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;
/// Experimental: Switch this to true to enable paged attention on CUDA devices.
/// Under load (dynamo-run batch mode) paged attention sometimes returns an immediate
/// finish_reason=stop and no tokens for one of the requests.
/// Has no effect unless the crate is also built with the `cuda` feature.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;
/// Initial message we send to mistral.rs to warm it up. We may not need this.
const WARMUP_MESSAGE: &str = "This is a test message. Respond only with 'OK'.";
/// Loads `model` into a mistral.rs engine and wraps it in an `EngineDispatcher` so it can be
/// used as a `StreamingEngine`.
///
/// # Errors
/// Returns the error from `MistralRsEngine::new` if the model cannot be loaded.
pub async fn make_engine(model: &LocalModel) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
    let dispatcher = EngineDispatcher::new(MistralRsEngine::new(model).await?);
    Ok(Arc::new(dispatcher) as Arc<dyn StreamingEngine>)
}
/// Selects the device to load the model on: Metal device 0 when built with the `metal`
/// feature, otherwise whatever `Device::cuda_if_available(0)` returns.
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = Device::cuda_if_available(0)?;
    Ok(device)
}
/// A loaded mistral.rs model that serves chat and completion requests.
struct MistralRsEngine {
// The mistral.rs runner that requests are sent to.
mistralrs: Arc<MistralRs>,
// Upper bound applied to the `max_len` sampling parameter of every request.
context_length: usize,
// Model name sent as `model_id` with every request.
display_name: String,
}
impl MistralRsEngine {
/// Loads the model at `model.path()` into mistral.rs, builds the runner and sends a single
/// warmup request before returning the engine.
///
/// The loader is chosen from the path and name: a file is loaded as GGUF, a name that
/// matches `is_vision_model` uses the vision loader, anything else is treated as an
/// HF repo directory.
///
/// # Errors
/// Fails if the path is malformed, the loader or pipeline cannot be built, or the paged
/// attention config is missing. A failed warmup request does not fail construction.
async fn new(model: &LocalModel) -> pipeline_error::Result<Self> {
let model_path = model.path();
// Name some None's for clarity
let chat_template = None;
let tokenizer_json = None;
let no_kv_cache = false;
let jinja_explicit = None;
let display_name = model.display_name();
let loader = if model_path.is_file() {
// Load from a GGUF
let Some(model_filename) = model_path.file_name() else {
pipeline_error::bail!("Missing filename in model path");
};
let Some(model_dir) = model_path.parent() else {
pipeline_error::bail!("Invalid model path");
};
GGUFLoaderBuilder::new(
chat_template,
None,
model_dir.display().to_string(),
vec![model_filename.to_string_lossy().into_owned()],
GGUFSpecificConfig::default(),
no_kv_cache,
jinja_explicit,
)
.build()
} else if is_vision_model(display_name) {
let vlt = if is_gemma3(display_name) {
VisionLoaderType::Gemma3
} else if is_llama4(display_name) {
VisionLoaderType::Llama4
} else {
// NOTE(review): `is_vision_model` is true only if one of the two checks above is,
// so this branch looks unreachable today.
panic!("Unsupported vision model {display_name}");
};
VisionLoaderBuilder::new(
VisionSpecificConfig::default(),
chat_template,
tokenizer_json,
Some(model_path.display().to_string()),
jinja_explicit,
)
.build(Some(vlt))
} else {
// Load from a HF repo dir
NormalLoaderBuilder::new(
NormalSpecificConfig::default(),
chat_template,
tokenizer_json,
Some(model_path.display().to_string()),
no_kv_cache,
jinja_explicit,
)
.build(None)?
};
// Context length from the model card, falling back to the mistral.rs default when it is 0.
let mut max_seq_len = model.card().context_length as usize;
if max_seq_len == 0 {
tracing::info!("context_length is 0. Probably error reading from model.");
max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
}
// Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
Some(PagedAttentionConfig::new(
None, // Block size, default 32
MemoryGpuConfig::ContextSize(max_seq_len),
PagedCacheType::Auto,
)?)
} else {
None
};
let device_map_params = if is_vision_model(model.display_name()) {
AutoDeviceMapParams::Vision {
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
max_image_shape: (0, 0),
max_num_images: 0,
}
} else {
AutoDeviceMapParams::Text {
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
}
};
// Load, into a Pipeline
let pipeline = loader.load_model_from_hf(
None,
TokenSource::None, // The model was already downloaded
&ModelDType::Auto,
&best_device()?,
false,
DeviceMapSetting::Auto(device_map_params),
// Llama 4 models are loaded with Q4K in-situ quantization, everything else unquantized.
if is_llama4(display_name) {
Some(IsqType::Q4K)
} else {
None
},
paged_attention_config,
)?;
// Same condition as for `paged_attention_config` above, so the two stay in step.
let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
Some(conf) => conf.clone(),
None => {
anyhow::bail!("Failed loading model config");
}
};
SchedulerConfig::PagedAttentionMeta {
max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
config,
}
} else {
SchedulerConfig::DefaultScheduler {
// Safety: a zero max_seq_len was replaced by DEFAULT_MAX_SEQ_LEN above
// (presumably non-zero), so the unwrap is not expected to fail.
method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
}
};
// Create the MistralRs, which is a runner
let throughput_logging = false;
let search_embedding_model = None;
let builder = MistralRsBuilder::new(
pipeline.clone(),
scheduler,
throughput_logging,
search_embedding_model,
)
.with_prefix_cache_n(16);
let engine = MistralRsEngine {
mistralrs: builder.build().await,
context_length: max_seq_len,
display_name: display_name.to_string(),
};
// skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
let _ = engine.mistralrs.next_request_id();
// Perform warmup request
let (tx, mut rx) = channel(1);
let mistralrs_request_id = engine.mistralrs.next_request_id();
let warmup_request = Request::Normal(Box::new(NormalRequest {
id: mistralrs_request_id,
model_id: Some(display_name.to_string()),
messages: RequestMessage::Chat {
messages: vec![IndexMap::from([
("role".to_string(), Either::Left("user".to_string())),
(
"content".to_string(),
Either::Left(WARMUP_MESSAGE.to_string()),
),
])],
enable_thinking: Some(false),
},
sampling_params: SamplingParams::deterministic(),
response: tx,
return_logprobs: false,
is_streaming: false,
constraint: Constraint::None,
suffix: None,
tools: None,
tool_choice: None,
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
truncate_sequence: false,
}));
// Send warmup request and consume response.
// Best effort: if the sender is unavailable, the send fails or no response arrives,
// the warmup is skipped without logging and the engine is still returned.
if let Ok(sender) = engine.mistralrs.get_sender(None)
&& let Ok(()) = sender.send(warmup_request).await
&& let Some(response) = rx.recv().await
{
match response.as_result() {
Ok(r) => {
tracing::debug!(mistralrs_request_id, "Warmup response: {r:?}");
}
Err(err) => {
tracing::error!(mistralrs_request_id, %err, "Failed converting response to result.");
}
}
}
Ok(engine)
}
}
/// Chat completions: forwards the user messages to mistral.rs and streams the answer back
/// as OpenAI style chat completion chunks.
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for MistralRsEngine
{
/// Sends one streaming chat request to mistral.rs and returns the response stream.
///
/// # Errors
/// Fails if a user message is not plain text, if the request contains no user message,
/// or if the request cannot be handed to the mistral.rs runner.
async fn generate(
&self,
request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
let request_id = ctx.id().to_string();
let (tx, mut rx) = channel(10_000);
let mut messages = vec![];
for m in request.inner.messages {
// Only user messages are forwarded; every other message kind is skipped.
// NOTE(review): this drops system prompts and earlier assistant turns — confirm intended.
let dynamo_async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
continue;
};
let dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
inner_m.content
else {
anyhow::bail!("Only Text type chat completion supported");
};
let r = IndexMap::from([
("role".to_string(), Either::Left("user".to_string())),
("content".to_string(), Either::Left(content)),
]);
messages.push(r);
}
if messages.is_empty() {
anyhow::bail!("Empty request");
}
// Any sampling parameter the request leaves unset falls back to the deterministic preset.
let det = SamplingParams::deterministic();
// allow deprecated because max_tokens
#[allow(deprecated)]
let sampling_params = SamplingParams {
temperature: request
.inner
.temperature
.map(|t| t as f64)
.or(det.temperature),
top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
top_n_logprobs: request
.inner
.top_logprobs
.map(|t| t as usize)
.unwrap_or(det.top_n_logprobs),
frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
repetition_penalty: det.repetition_penalty,
stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
max_len: {
// `max_completion_tokens` wins over the deprecated `max_tokens`.
let requested_max_tokens = request
.inner
.max_completion_tokens
.or(request.inner.max_tokens)
.map(|m| m as usize);
// Ensure max_len doesn't exceed context length
match requested_max_tokens {
Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
None => det
.max_len
.map(|len| std::cmp::min(len, self.context_length)),
}
},
logits_bias: request
.inner
.logit_bias
.map(to_logit_bias)
.or(det.logits_bias),
// These are not in async-openai yet
top_k: det.top_k,
min_p: det.min_p,
n_choices: 1,
dry_params: det.dry_params,
};
let mistralrs_request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(Box::new(NormalRequest {
id: mistralrs_request_id,
model_id: Some(self.display_name.clone()),
messages: RequestMessage::Chat {
messages,
enable_thinking: None,
},
sampling_params,
response: tx,
return_logprobs: request.inner.logprobs.unwrap_or_default(),
is_streaming: true,
constraint: Constraint::None,
suffix: None,
tools: None,
tool_choice: None,
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
truncate_sequence: false,
}));
self.mistralrs
.get_sender(None)?
.send(mistralrs_request)
.await?;
// Turn the mistral.rs channel into a stream of chunks. The stream ends when the channel
// closes, a response is an error, a chunk has no content, or a finish reason arrives.
let output = stream! {
while let Some(response) = rx.recv().await {
let response = match response.as_result() {
Ok(r) => r,
Err(err) => {
tracing::error!(mistralrs_request_id, %err, "Failed converting mistralrs channel response to result.");
break;
}
};
match response {
ResponseOk::Chunk(c) => {
// NOTE(review): indexing `choices[0]` panics on an empty list — presumably
// mistral.rs always sends one choice since `n_choices` is 1; confirm.
let Some(from_assistant) = c.choices[0].delta.content.clone() else {
tracing::warn!(mistralrs_request_id, "No content from mistralrs. Abandoning request.");
break;
};
// Map the mistral.rs finish reason string; unknown values are reported as Stop.
let finish_reason = match &c.choices[0].finish_reason.as_deref() {
Some("stop") | Some("canceled") => {
Some(FinishReason::Stop)
}
Some("length") => {
Some(FinishReason::Length)
}
Some(s) => {
tracing::warn!(mistralrs_request_id, stop_reason = s, "Unknow stop reason");
Some(FinishReason::Stop)
}
None => None,
};
//tracing::trace!("from_assistant: {from_assistant}");
#[allow(deprecated)]
let delta = NvCreateChatCompletionStreamResponse {
id: format!("chatcmpl-{request_id}"),
choices: vec![dynamo_async_openai::types::ChatChoiceStream{
index: 0,
delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta{
//role: c.choices[0].delta.role,
role: Some(dynamo_async_openai::types::Role::Assistant),
content: Some(dynamo_async_openai::types::ChatCompletionMessageContent::Text(from_assistant)),
tool_calls: None,
refusal: None,
function_call: None,
reasoning_content: None,
},
logprobs: None,
finish_reason,
stop_reason: None,
}],
model: c.model,
created: c.created as u32,
object: c.object.clone(),
usage: None,
system_fingerprint: Some(c.system_fingerprint),
service_tier: None,
nvext: None,
};
let ann = Annotated{
id: None,
data: Some(delta),
event: None,
comment: None,
};
yield ann;
if finish_reason.is_some() {
//tracing::trace!(mistralrs_request_id, "Finish reason: {finish_reason:?}");
break;
}
},
// Any other response kind is logged and the loop keeps waiting.
x => tracing::error!(mistralrs_request_id, "Unhandled. {x:?}"),
}
}
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
/// Converts OpenAI stop sequences into mistralrs `StopTokens::Seqs`.
/// A single string becomes a one-element list; an array is passed through unchanged.
fn to_stop_tokens(t: dynamo_async_openai::types::Stop) -> StopTokens {
    use dynamo_async_openai::types::Stop;

    let sequences = match t {
        Stop::String(single) => vec![single],
        Stop::StringArray(many) => many,
    };
    StopTokens::Seqs(sequences)
}
/// Converts an OpenAI logit bias map (string token ids to JSON numbers) into the
/// `u32 -> f32` map mistralrs expects.
/// The input is expected to look like: {"3721": -100, "17765": 100}
///
/// The conversion is all-or-nothing: the first key that is not a `u32` or value that is not
/// a number logs a warning and makes the whole result an empty map.
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let convert_entry = |(key, value): (&String, &serde_json::Value)| -> Option<(u32, f32)> {
        let token_id = match key.parse::<u32>() {
            Ok(id) => id,
            Err(err) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return None;
            }
        };
        match value.as_f64() {
            Some(bias) => Some((token_id, bias as f32)),
            None => {
                tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
                None
            }
        }
    };
    // Collecting into `Option` stops at the first entry that failed to convert.
    lb.iter()
        .map(convert_entry)
        .collect::<Option<HashMap<u32, f32>>>()
        .unwrap_or_default()
}
/// Text completions: forwards the prompt to mistral.rs and streams the generated text back.
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for MistralRsEngine
{
/// Sends one streaming completion request to mistral.rs and returns the response stream.
///
/// # Errors
/// Fails if the request cannot be handed to the mistral.rs runner.
async fn generate(
&self,
request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
let (tx, mut rx) = channel(10_000);
let response_generator = request.response_generator(ctx.id().to_string());
let messages = RequestMessage::Completion {
text: prompt_to_string(&request.inner.prompt),
echo_prompt: false,
best_of: Some(1),
};
// Any sampling parameter the request leaves unset falls back to the deterministic preset.
let det = SamplingParams::deterministic();
// allow deprecated because max_tokens
#[allow(deprecated)]
let sampling_params = SamplingParams {
temperature: request
.inner
.temperature
.map(|t| t as f64)
.or(det.temperature),
top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
top_n_logprobs: request
.inner
.logprobs
.map(|t| t as usize)
.unwrap_or(det.top_n_logprobs),
frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
repetition_penalty: det.repetition_penalty,
stop_toks: request
.inner
.stop
.clone()
.map(to_stop_tokens)
.or(det.stop_toks),
max_len: {
let requested_max_tokens = request.inner.max_tokens.map(|m| m as usize);
// Ensure max_len doesn't exceed context length
match requested_max_tokens {
Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
None => det
.max_len
.map(|len| std::cmp::min(len, self.context_length)),
}
},
logits_bias: request
.inner
.logit_bias
.clone()
.map(to_logit_bias)
.or(det.logits_bias),
// These are not in async-openai yet
top_k: det.top_k,
min_p: det.min_p,
n_choices: 1,
dry_params: det.dry_params,
};
let mistralrs_request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(Box::new(NormalRequest {
id: mistralrs_request_id,
model_id: Some(self.display_name.clone()),
messages,
sampling_params,
response: tx,
return_logprobs: false,
is_streaming: true,
constraint: Constraint::None,
suffix: None,
tools: None,
tool_choice: None,
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
truncate_sequence: false,
}));
self.mistralrs
.get_sender(None)?
.send(mistralrs_request)
.await?;
// Turn the mistral.rs channel into a stream of chunks. The stream ends when the channel
// closes, a response is an error, or a finish reason arrives.
let output = stream! {
while let Some(response) = rx.recv().await {
let response = match response.as_result() {
Ok(r) => r,
Err(err) => {
tracing::error!(mistralrs_request_id, %err, "Failed converting mistralrs channel response to result.");
break;
}
};
match response {
ResponseOk::CompletionChunk(c) => {
// NOTE(review): indexing `choices[0]` panics on an empty list — presumably
// mistral.rs always sends one choice since `n_choices` is 1; confirm.
let from_assistant = c.choices[0].text.clone();
// Map the mistral.rs finish reason string; unknown values are reported as Stop.
let finish_reason = match &c.choices[0].finish_reason.as_deref() {
Some("stop") | Some("canceled") => {
Some(FinishReason::Stop)
}
Some("length") => {
Some(FinishReason::Length)
}
Some(s) => {
tracing::warn!(mistralrs_request_id, stop_reason = s, "Unknow stop reason");
Some(FinishReason::Stop)
}
None => None,
};
// NOTE(review): `finish_reason` is only used to end the stream below; it is not
// passed to `create_choice` (the trailing arguments are `None`) — confirm whether
// the emitted chunk should carry it.
#[allow(deprecated)]
let inner = response_generator.create_choice(0, Some(from_assistant), None, None);
let ann = Annotated{
id: None,
data: Some(inner),
event: None,
comment: None,
};
yield ann;
if finish_reason.is_some() {
break;
}
},
// Any other response kind is logged and the loop keeps waiting.
x => tracing::error!(mistralrs_request_id, "Unhandled. {x:?}"),
}
}
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
/// True when the model name belongs to one of the vision families handled here
/// (Gemma 3 or Llama 4).
fn is_vision_model(s: &str) -> bool {
    if is_gemma3(s) {
        return true;
    }
    is_llama4(s)
}
/// True when the name contains `gemma-3`, ignoring case.
fn is_gemma3(s: &str) -> bool {
    const FAMILY: &str = "gemma-3";
    let normalized = s.to_lowercase();
    normalized.contains(FAMILY)
}
/// True when the name contains `llama-4`, ignoring case.
fn is_llama4(s: &str) -> bool {
    const FAMILY: &str = "llama-4";
    let normalized = s.to_lowercase();
    normalized.contains(FAMILY)
}
#[async_trait]
impl
AsyncEngine<
SingleIn<NvCreateEmbeddingRequest>,
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
Error,
> for MistralRsEngine
{
async fn generate(
&self,
_request: SingleIn<NvCreateEmbeddingRequest>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
unimplemented!()
}
}
......@@ -50,7 +50,7 @@ async fn main_loop(
}
let theme = dialoguer::theme::ColorfulTheme::default();
// Initial prompt is the pipe case: `echo "Hello" | dynamo-run ..`
// Initial prompt is from piped stdin.
// We run that single prompt and exit
let single = initial_prompt.is_some();
let mut history = dialoguer::BasicHistory::default();
......
......@@ -122,7 +122,6 @@ impl Default for LoggingConfig {
("tokenizers".to_string(), "error".to_string()),
("axum".to_string(), "error".to_string()),
("tonic".to_string(), "error".to_string()),
("mistralrs_core".to_string(), "error".to_string()),
("hf_hub".to_string(), "error".to_string()),
("opentelemetry".to_string(), "error".to_string()),
("opentelemetry-otlp".to_string(), "error".to_string()),
......
......@@ -24,12 +24,6 @@ echo "🚀 Starting dynamo disaggregated serving setup without LMCache:"
echo " Model: $MODEL_URL"
echo " Port: 8000"
echo " Mode: Disaggregated (prefill + decode workers)"
# Kill any existing dynamo processes
echo "🧹 Cleaning up any existing dynamo processes..."
pkill -f "dynamo-run" || true
sleep 2
echo "🔧 Starting dynamo disaggregated serving without LMCache..."
python -m dynamo.frontend &
......
......@@ -22,12 +22,6 @@ fi
echo "🚀 Starting dynamo setup without LMCache:"
echo " Model: $MODEL_URL"
echo " Port: 8000"
# Kill any existing dynamo processes
echo "🧹 Cleaning up any existing dynamo processes..."
pkill -f "dynamo-run" || true
sleep 2
echo "🔧 Starting dynamo worker without LMCache..."
python -m dynamo.frontend &
......
......@@ -25,12 +25,6 @@ echo " Model: $MODEL_URL"
echo " Port: 8000"
echo " Mode: Disaggregated (prefill + decode workers) + LMCache"
echo " !! Remember to kill the old dynamo processes otherwise the port will be busy !!"
# Kill any existing dynamo processes
echo "🧹 Cleaning up any existing dynamo processes..."
pkill -f "dynamo-run" || true
sleep 2
echo "🔧 Starting dynamo disaggregated serving with LMCache enabled..."
python -m dynamo.frontend &
......
......@@ -23,12 +23,6 @@ echo "🚀 Starting dynamo setup with LMCache:"
echo " Model: $MODEL_URL"
echo " Port: 8000"
echo " !! Remmber to kill the old dynamo processes other wise the port will be busy !! "
# Kill any existing dynamo processes
echo "🧹 Cleaning up any existing dynamo processes..."
pkill -f "dynamo-run" || true
sleep 2
echo "🔧 Starting dynamo worker with LMCache enabled..."
python -m dynamo.frontend &
......
......@@ -17,10 +17,6 @@ echo ""
cleanup() {
echo "🧹 Cleaning up running processes..."
# Kill any remaining dynamo processes
pkill -f "dynamo-run" || true
pkill -f "components/main.py" || true
# Stop docker services
docker compose -f ../../deploy/docker-compose.yml down 2>/dev/null || true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment