Commit 404a78e9 authored by Graham King, committed by GitHub

feat(mistralrs): Let the engine enforce max tokens (#134)

Previously we tokenized and counted tokens to stop when max tokens was reached. Now we let the mistral.rs engine enforce the limit, which saves the extra tokenization step.
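A minimal sketch of the new flow, assuming `request.inner` is async-openai's CreateChatCompletionRequest as used in the diff below (the surrounding glue is illustrative, not the exact code):

// Sketch: forward the caller's limit to mistral.rs via SamplingParams
// instead of counting generated tokens with a tokenizer on our side.
use mistralrs::SamplingParams;

let det = SamplingParams::deterministic();
#[allow(deprecated)] // async-openai's max_tokens is deprecated; kept only as a fallback
let sampling_params = SamplingParams {
    max_len: request
        .inner
        .max_completion_tokens        // preferred, non-deprecated field
        .or(request.inner.max_tokens) // deprecated fallback
        .map(|m| m as usize)
        .or(det.max_len),             // otherwise let the engine use its default
    ..det
};
// When the limit is hit the engine reports finish_reason == "length",
// which the streaming loop maps to FinishReason::Length (see the diff below).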

Also, dynamo-run now prints which engines are compiled in as part of its help message, and there are some minor lint fixes.
parent 941032da
@@ -114,7 +114,6 @@ pub async fn run(
.await
}
#[allow(deprecated)]
async fn main_loop(
cancel_token: CancellationToken,
service_name: &str,
@@ -172,7 +171,9 @@ async fn main_loop(
.messages(messages.clone())
.model(service_name)
.stream(true)
.max_tokens(MAX_TOKENS)
.max_completion_tokens(MAX_TOKENS)
.temperature(0.7)
.n(1) // only generate one response
.build()?;
// TODO We cannot set min_tokens with async-openai
@@ -190,6 +191,9 @@ async fn main_loop(
let mut stdout = std::io::stdout();
let mut assistant_message = String::new();
while let Some(item) = stream.next().await {
if cancel_token.is_cancelled() {
break;
}
match (item.data.as_ref(), item.event.as_deref()) {
(Some(data), _) => {
// Normal case
@@ -226,15 +230,10 @@ async fn main_loop(
assistant_message,
);
// ALLOW: function_call is deprecated
let assistant_message = async_openai::types::ChatCompletionRequestMessage::Assistant(
async_openai::types::ChatCompletionRequestAssistantMessage {
content: Some(assistant_content),
refusal: None,
name: None,
audio: None,
tool_calls: None,
function_call: None,
..Default::default()
},
);
messages.push(assistant_message);
@@ -28,7 +28,6 @@ Example:
- cd target/release
- ./dynamo-run hf_checkouts/Llama-3.2-3B-Instruct/
- OR: ./dynamo-run Llama-3.2-1B-Instruct-Q4_K_M.gguf
"#;
const ZMQ_SOCKET_PREFIX: &str = "dyn";
@@ -126,6 +125,11 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
if args.is_empty() || args[0] == "-h" || args[0] == "--help" {
println!("{USAGE}");
println!("{HELP}");
println!(
"Available engines: {}",
Output::available_engines().join(", ")
);
return Ok(());
}
for arg in env::args().skip(1).take(2) {
@@ -117,6 +117,9 @@ pub enum Output {
/// tokens. We do the pre-processing.
#[cfg(feature = "python")]
PythonTok(String),
//
// DEVELOPER NOTE
// If you add an engine, add it to `available_engines` below, and to `Default` if it makes sense.
}
impl TryFrom<&str> for Output {
@@ -192,10 +195,10 @@ impl fmt::Display for Output {
Output::Endpoint(path) => path,
#[cfg(feature = "python")]
Output::PythonStr(path) => path,
Output::PythonStr(_) => "pystr",
#[cfg(feature = "python")]
Output::PythonTok(path) => path,
Output::PythonTok(_) => "pytok",
};
write!(f, "{s}")
}
@@ -233,3 +236,42 @@ impl Default for Output {
out
}
}
impl Output {
#[allow(unused_mut)]
pub fn available_engines() -> Vec<String> {
let mut out = vec!["echo_core".to_string(), "echo_full".to_string()];
#[cfg(feature = "mistralrs")]
{
out.push(Output::MistralRs.to_string());
}
#[cfg(feature = "llamacpp")]
{
out.push(Output::LlamaCpp.to_string());
}
#[cfg(feature = "sglang")]
{
out.push(Output::SgLang.to_string());
}
#[cfg(feature = "vllm")]
{
out.push(Output::Vllm.to_string());
}
#[cfg(feature = "python")]
{
out.push(Output::PythonStr("file.py".to_string()).to_string());
out.push(Output::PythonTok("file.py".to_string()).to_string());
}
#[cfg(feature = "trtllm")]
{
out.push(Output::TrtLLM.to_string());
}
out
}
}
@@ -14,7 +14,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::{cmp::min, env, num::NonZero, path::Path, sync::Arc};
use std::{num::NonZero, path::Path, sync::Arc};
use async_openai::types::FinishReason;
use async_stream::stream;
@@ -25,8 +25,7 @@ use mistralrs::{
AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens,
TokenSource,
Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource,
};
use tokio::sync::mpsc::channel;
@@ -40,15 +39,8 @@ use crate::protocols::openai::chat_completions::{
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
/// If user does not provide a max_tokens limit prompt+output to this many
const DEFAULT_MAX_TOKENS: i32 = 8192;
/// TODO: tune
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5;
/// The environment variable which can hold the Hugging Face token, if any, in order
const HF_TOKEN_VARS: [&str; 3] = ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN"];
pub async fn make_engine(
gguf_path: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
@@ -71,22 +63,10 @@ fn best_device() -> pipeline_error::Result<Device> {
struct MistralRsEngine {
mistralrs: Arc<MistralRs>,
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync + 'static>>,
}
impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
let mut hf_token_source = TokenSource::CacheToken;
// We might be trying to download a repo from Hugging Face. See if we have a token.
if !model_path.exists() {
for v_name in HF_TOKEN_VARS {
if env::var(v_name).is_ok() {
tracing::debug!("Using Hugging Face token from {v_name}");
hf_token_source = TokenSource::EnvVar(v_name.to_string());
break;
}
}
}
let loader = if model_path.is_file() {
// Load from a GGUF
let Some(model_filename) = model_path.file_name() else {
@@ -142,7 +122,7 @@ impl MistralRsEngine {
// Load, into a Pipeline
let pipeline = loader.load_model_from_hf(
None,
hf_token_source,
TokenSource::None, // The model was already downloaded
&ModelDType::Auto,
&best_device()?,
false,
@@ -176,7 +156,6 @@ impl MistralRsEngine {
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
Ok(MistralRsEngine {
mistralrs: builder.build(),
pipeline,
})
}
}
@@ -196,27 +175,16 @@ impl
let (request, context) = request.transfer(());
let ctx = context.context();
let (tx, mut rx) = channel(10_000);
let maybe_tok = self.pipeline.lock().await.tokenizer();
let mut prompt_tokens = 0i32;
let mut messages = vec![];
for m in request.inner.messages {
let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
continue;
};
let content = match inner_m.content {
async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt) => {
if let Some(tok) = maybe_tok.as_ref() {
prompt_tokens = tok
.encode(prompt.clone(), false)
.map(|e| e.len() as i32)
.unwrap_or(0);
}
prompt
}
_ => {
anyhow::bail!("Only Text type is supported");
}
let async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
inner_m.content
else {
anyhow::bail!("Only Text type chat completion supported");
};
let r = IndexMap::from([
("role".to_string(), Either::Left("user".to_string())),
@@ -227,17 +195,10 @@ impl
if messages.is_empty() {
anyhow::bail!("Empty request");
}
// TODO tracing::trace print the latest prompt, which should be the last message at user
// level.
//tracing::info!(prompt_tokens, "Received prompt");
let limit = DEFAULT_MAX_TOKENS - prompt_tokens;
#[allow(deprecated)]
let max_output_tokens = min(
request.inner.max_tokens.map(|x| x as i32).unwrap_or(limit),
limit,
);
let det = SamplingParams::deterministic();
// allow deprecated because max_tokens
#[allow(deprecated)]
let sampling_params = SamplingParams {
temperature: request
.inner
@@ -256,6 +217,7 @@ impl
max_len: request
.inner
.max_completion_tokens
.or(request.inner.max_tokens)
.map(|m| m as usize)
.or(det.max_len),
logits_bias: request
@@ -287,7 +249,6 @@ impl
self.mistralrs.get_sender()?.send(mistralrs_request).await?;
let mut used_output_tokens = 0;
let output = stream! {
while let Some(response) = rx.recv().await {
let response = match response.as_result() {
@@ -303,18 +264,17 @@ impl
tracing::warn!("No content from mistralrs. Abandoning request.");
break;
};
if let Some(tok) = maybe_tok.as_ref() {
used_output_tokens += tok
.encode(from_assistant.clone(), false)
.map(|e| e.len() as i32)
.unwrap_or(0);
}
let finish_reason = match &c.choices[0].finish_reason {
Some(_fr) => Some(FinishReason::Stop), //Some(fr.parse::<FinishReason>().unwrap_or(FinishReason::Stop)),
None if used_output_tokens >= max_output_tokens => {
tracing::debug!(used_output_tokens, max_output_tokens, "Met or exceed max_tokens. Stopping.");
let finish_reason = match &c.choices[0].finish_reason.as_deref() {
Some("stop") | Some("canceled") => {
Some(FinishReason::Stop)
}
Some("length") => {
Some(FinishReason::Length)
}
Some(s) => {
tracing::warn!(stop_reason = s, "Unknown stop reason");
Some(FinishReason::Stop)
}
None => None,
};
//tracing::trace!("from_assistant: {from_assistant}");