Commit 05765cd4 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(vllm,sglang): Let the engine enforce max tokens (#216)

Previously several parts of the stack ensured max tokens (for this single request) was set.

Now only text input sets it (to 8k). Everything else leaves it as is, potentially blank. The engines themselves have very small defaults: 16 for vllm and 128 for sglang.

Also fix dynamo-run CUDA startup message to only print if we're using an engine that would benefit from it (mistralrs, llamacpp).
parent 8891aa0c
...@@ -24,6 +24,7 @@ use crate::input::common; ...@@ -24,6 +24,7 @@ use crate::input::common;
use crate::EngineConfig; use crate::EngineConfig;
/// Max response tokens for each single query. Must be less than model context size. /// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to override this
const MAX_TOKENS: u32 = 8192; const MAX_TOKENS: u32 = 8192;
pub async fn run( pub async fn run(
......
...@@ -108,25 +108,6 @@ fn main() -> anyhow::Result<()> { ...@@ -108,25 +108,6 @@ fn main() -> anyhow::Result<()> {
} }
} }
} }
#[cfg(any(feature = "mistralrs", feature = "llamacpp"))]
{
#[cfg(feature = "cuda")]
{
tracing::info!("CUDA on");
}
#[cfg(feature = "metal")]
{
tracing::info!("Metal on");
}
#[cfg(feature = "vulkan")]
{
tracing::info!("Vulkan on");
}
#[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
tracing::info!(
"CPU mode. Rebuild with `--features cuda|metal|vulkan` for better performance"
);
}
// max_worker_threads and max_blocking_threads from env vars or config file. // max_worker_threads and max_blocking_threads from env vars or config file.
let rt_config = dynamo_runtime::RuntimeConfig::from_settings()?; let rt_config = dynamo_runtime::RuntimeConfig::from_settings()?;
...@@ -190,6 +171,7 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> { ...@@ -190,6 +171,7 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
default_engine default_engine
} }
}; };
print_cuda(&out_opt);
// Clap skips the first argument expecting it to be the binary name, so add it back // Clap skips the first argument expecting it to be the binary name, so add it back
// Note `--model-path` has index=1 (in lib.rs) so that doesn't need a flag. // Note `--model-path` has index=1 (in lib.rs) so that doesn't need a flag.
...@@ -208,3 +190,39 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> { ...@@ -208,3 +190,39 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
) )
.await .await
} }
/// If the user will benefit from CUDA/Metal/Vulkan, remind them to build with it.
/// If they have it, celebrate!
// Only mistralrs and llamacpp need to be built with CUDA.
// The Python engines only need it at runtime.
#[cfg(any(feature = "mistralrs", feature = "llamacpp"))]
fn print_cuda(output: &Output) {
    // These engines may be compiled in, but are they the chosen one?
    // Only log for the engine actually selected for this run; accelerator
    // build flags are irrelevant to the Python-backed engines.
    match output {
        #[cfg(feature = "mistralrs")]
        Output::MistralRs => {}
        #[cfg(feature = "llamacpp")]
        Output::LlamaCpp => {}
        _ => {
            // Some other engine was chosen; stay quiet.
            return;
        }
    }
    // Report whichever accelerator backend this binary was compiled with.
    #[cfg(feature = "cuda")]
    {
        tracing::info!("CUDA on");
    }
    #[cfg(feature = "metal")]
    {
        tracing::info!("Metal on");
    }
    #[cfg(feature = "vulkan")]
    {
        tracing::info!("Vulkan on");
    }
    // No accelerator feature at all: nudge the user toward a faster build.
    #[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
    tracing::info!("CPU mode. Rebuild with `--features cuda|metal|vulkan` for better performance");
}
/// No-op fallback when neither local engine (mistralrs, llamacpp) is compiled
/// in: accelerator build features only matter for those engines.
// Takes `&Output` to match the feature-gated variant's signature; the call
// site passes `print_cuda(&out_opt)`, so taking `Output` by value here would
// fail to compile whenever neither feature is enabled.
#[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
fn print_cuda(_output: &Output) {}
...@@ -142,16 +142,7 @@ impl ...@@ -142,16 +142,7 @@ impl
request: SingleIn<BackendInput>, request: SingleIn<BackendInput>,
next: ServerStreamingEngine<BackendInput, Annotated<LLMEngineOutput>>, next: ServerStreamingEngine<BackendInput, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<BackendOutput>>> { ) -> Result<ManyOut<Annotated<BackendOutput>>> {
// possible use the request let stop_conditions = request.stop_conditions.clone();
let mut stop_conditions = request.stop_conditions.clone();
// preprocessor should have set max_tokens
// assert!(stop_conditions.max_tokens.is_some());
if stop_conditions.max_tokens.is_none() {
log::warn!("max_tokens is not set in stop_conditions; fixme");
stop_conditions.max_tokens = Some(256);
}
let next_stream = next.generate(request).await?; let next_stream = next.generate(request).await?;
let context = next_stream.context(); let context = next_stream.context();
...@@ -265,9 +256,6 @@ pub struct Decoder { ...@@ -265,9 +256,6 @@ pub struct Decoder {
// do not trigger stop conditions until at least this many tokens have been generated // do not trigger stop conditions until at least this many tokens have been generated
min_tokens: u32, min_tokens: u32,
// maximum number of tokens to generate - the llm engine should enforce this
max_tokens: u32,
// single tokens that if found in the response will trigger a stop condition after the // single tokens that if found in the response will trigger a stop condition after the
// minimum number of tokens have been generated // minimum number of tokens have been generated
hidden_stop_ids: HashSet<TokenIdType>, hidden_stop_ids: HashSet<TokenIdType>,
...@@ -372,7 +360,6 @@ impl Decoder { ...@@ -372,7 +360,6 @@ impl Decoder {
//visible_stop_ids: HashSet::new(), //visible_stop_ids: HashSet::new(),
//visible_stop_sequences: Vec::new(), //visible_stop_sequences: Vec::new(),
min_tokens: stop_condition.min_tokens.unwrap_or(0), min_tokens: stop_condition.min_tokens.unwrap_or(0),
max_tokens: stop_condition.max_tokens.expect("max_tokens is required"),
generated_tokens: 0, generated_tokens: 0,
jail: String::new(), jail: String::new(),
jail_max_bytes, jail_max_bytes,
...@@ -407,14 +394,6 @@ impl Decoder { ...@@ -407,14 +394,6 @@ impl Decoder {
)); ));
} }
// next check max_tokens limit
if self.generated_tokens >= self.max_tokens {
return Ok(StepResult::with_stop_trigger(
token,
StopTrigger::MaxTokensLimit,
));
}
// check stop sequences - the jail will always hold at least the largest stop sequence // check stop sequences - the jail will always hold at least the largest stop sequence
// if jail_max_bytes is 0, then there are no stop sequences // if jail_max_bytes is 0, then there are no stop sequences
if self.jail_max_bytes > 0 { if self.jail_max_bytes > 0 {
......
...@@ -47,9 +47,6 @@ use crate::protocols::common::preprocessor::PreprocessedRequest; ...@@ -47,9 +47,6 @@ use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason; use crate::protocols::common::FinishReason;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
/// If user does not provide a max_tokens limit to this many
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Wait this long for the sglang sub-process to stop after we send it a KILL /// Wait this long for the sglang sub-process to stop after we send it a KILL
const SGLANG_STOP_TIMEOUT: Duration = Duration::from_millis(1500); const SGLANG_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
...@@ -99,7 +96,6 @@ pub struct WorkRequest { ...@@ -99,7 +96,6 @@ pub struct WorkRequest {
struct ActiveRequest { struct ActiveRequest {
tx: Sender<Annotated<LLMEngineOutput>>, tx: Sender<Annotated<LLMEngineOutput>>,
num_output_tokens_so_far: Option<i32>, num_output_tokens_so_far: Option<i32>,
max_tokens: i32,
} }
/// Python imports /// Python imports
...@@ -487,7 +483,7 @@ async fn start_sglang( ...@@ -487,7 +483,7 @@ async fn start_sglang(
tokio::spawn(async move { tokio::spawn(async move {
let mut lines = stdout.lines(); let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await { while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("{LOG_PREFIX}{tp_rank} {line}"); tracing::debug!("{LOG_PREFIX}{tp_rank} {line}");
} }
}); });
...@@ -505,7 +501,7 @@ async fn start_sglang( ...@@ -505,7 +501,7 @@ async fn start_sglang(
} }
3 => { 3 => {
// Normal log line. Skip Python's date/time // Normal log line. Skip Python's date/time
tracing::info!("{LOG_PREFIX}{tp_rank} {}", &cap[2]); tracing::debug!("{LOG_PREFIX}{tp_rank} {}", &cap[2]);
} }
x => { x => {
unreachable!("sglang log re only has two capture groups, so {x} entries is impossible"); unreachable!("sglang log re only has two capture groups, so {x} entries is impossible");
...@@ -608,21 +604,18 @@ async fn input_loop( ...@@ -608,21 +604,18 @@ async fn input_loop(
.temperature .temperature
.unwrap_or(0.0) .unwrap_or(0.0)
.into(); .into();
let max_tokens = work_request
.request
.stop_conditions
.max_tokens
.unwrap_or(DEFAULT_MAX_TOKENS);
tracing::trace!("Received work request: {request_id}"); tracing::trace!("Received work request: {request_id}");
// Parts that don't change // Parts that don't change
let (py_request_id, sampling_params) = Python::with_gil(|py| { let (py_request_id, sampling_params) = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into(); let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into(); let mut sp_kwargs = vec![("temperature", py_temp)];
let sp_kwargs = [("temperature", py_temp), ("max_new_tokens", py_max_tokens)] if let Some(max_tokens) = work_request.request.stop_conditions.max_tokens {
.into_py_dict(py) let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
.unwrap(); // sglang defaults this to 128
sp_kwargs.push(("max_new_tokens", py_max_tokens));
}
let sp_kwargs = sp_kwargs.into_py_dict(py).unwrap();
let sampling_params = py_imports let sampling_params = py_imports
.sampling_params_type .sampling_params_type
.call(py, (), Some(&sp_kwargs)) .call(py, (), Some(&sp_kwargs))
...@@ -666,7 +659,6 @@ async fn input_loop( ...@@ -666,7 +659,6 @@ async fn input_loop(
}); });
let new_active_request = ActiveRequest { let new_active_request = ActiveRequest {
tx: work_request.response_channel, tx: work_request.response_channel,
max_tokens: max_tokens as i32,
num_output_tokens_so_far: None, num_output_tokens_so_far: None,
}; };
active_requests active_requests
...@@ -674,7 +666,6 @@ async fn input_loop( ...@@ -674,7 +666,6 @@ async fn input_loop(
.await .await
.insert(request_id, new_active_request); .insert(request_id, new_active_request);
//if let Err(err) = input_socket.send(vec![pickled_req].into()).await {
if let Err(err) = input_socket.send(pickled_req.into()).await { if let Err(err) = input_socket.send(pickled_req.into()).await {
tracing::error!("Error sending new request to sglang over zmq: {err}"); tracing::error!("Error sending new request to sglang over zmq: {err}");
} }
...@@ -732,6 +723,11 @@ async fn output_loop( ...@@ -732,6 +723,11 @@ async fn output_loop(
let token_ids: Vec<TokenIdType> = if sglang_finish_reason.is_none() { let token_ids: Vec<TokenIdType> = if sglang_finish_reason.is_none() {
req_out.decode_ids[idx][previous_total_toks..].into() req_out.decode_ids[idx][previous_total_toks..].into()
} else { } else {
tracing::trace!(
req_id,
?sglang_finish_reason,
"finished with finish reason"
);
// Request is over, sglang says so. // Request is over, sglang says so.
// The last token is the eos_token, don't forward it // The last token is the eos_token, don't forward it
remove_after.push(req_id.clone()); remove_after.push(req_id.clone());
...@@ -746,14 +742,7 @@ async fn output_loop( ...@@ -746,14 +742,7 @@ async fn output_loop(
finish_reason: sglang_finish_reason.map(|x| x.into()), finish_reason: sglang_finish_reason.map(|x| x.into()),
}; };
active.num_output_tokens_so_far = Some(next_total_toks); active.num_output_tokens_so_far = Some(next_total_toks);
let out = if next_total_toks <= active.max_tokens { let _ = active.tx.send(Annotated::from_data(out)).await;
Annotated::from_data(out)
} else {
// we exceeded max tokens, this request is over
remove_after.push(req_id.clone());
Annotated::from_data(LLMEngineOutput::length())
};
let _ = active.tx.send(out).await;
} }
None => { None => {
// sglang sends the finish response twice, I don't know why // sglang sends the finish response twice, I don't know why
......
...@@ -39,9 +39,6 @@ use crate::protocols::common::llm_backend::LLMEngineOutput; ...@@ -39,9 +39,6 @@ use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest; use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason; use crate::protocols::common::FinishReason;
/// If user does not provide a max_tokens limit to this many
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Wait this long for the vllm sub-process to stop after we send it a KILL /// Wait this long for the vllm sub-process to stop after we send it a KILL
const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500); const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
...@@ -82,7 +79,6 @@ pub struct WorkRequest { ...@@ -82,7 +79,6 @@ pub struct WorkRequest {
struct ActiveRequest { struct ActiveRequest {
tx: Sender<Annotated<LLMEngineOutput>>, tx: Sender<Annotated<LLMEngineOutput>>,
num_output_tokens_so_far: usize, num_output_tokens_so_far: usize,
max_tokens: usize,
} }
/// Python imports /// Python imports
...@@ -623,19 +619,17 @@ async fn input_loop( ...@@ -623,19 +619,17 @@ async fn input_loop(
.temperature .temperature
.unwrap_or(0.0) .unwrap_or(0.0)
.into(); .into();
let max_tokens = work_request
.request
.stop_conditions
.max_tokens
.unwrap_or(DEFAULT_MAX_TOKENS);
// Parts that don't change // Parts that don't change
let (py_request_id, sampling_params) = Python::with_gil(|py| { let (py_request_id, sampling_params) = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into(); let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into(); let mut sp_kwargs = vec![("temperature", py_temp)];
let sp_kwargs = [("temperature", py_temp), ("max_tokens", py_max_tokens)] if let Some(max_tokens) = work_request.request.stop_conditions.max_tokens {
.into_py_dict(py) let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
.unwrap(); // vllm defaults this to 16
sp_kwargs.push(("max_tokens", py_max_tokens));
}
let sp_kwargs = sp_kwargs.into_py_dict(py).unwrap();
let sampling_params = py_imports let sampling_params = py_imports
.sample_params_type .sample_params_type
.call(py, (), Some(&sp_kwargs)) .call(py, (), Some(&sp_kwargs))
...@@ -668,7 +662,6 @@ async fn input_loop( ...@@ -668,7 +662,6 @@ async fn input_loop(
let new_active_request = ActiveRequest { let new_active_request = ActiveRequest {
tx: work_request.response_channel, tx: work_request.response_channel,
max_tokens: max_tokens as usize,
num_output_tokens_so_far: 0, num_output_tokens_so_far: 0,
}; };
active_requests active_requests
...@@ -728,6 +721,7 @@ async fn output_loop( ...@@ -728,6 +721,7 @@ async fn output_loop(
if req_out.finished { if req_out.finished {
// The last token is the eos_token, don't forward it // The last token is the eos_token, don't forward it
// TODO: Look at req_out.finish_reason (Option<String>) and set out correctly.
let out = Annotated::from_data(LLMEngineOutput::stop()); let out = Annotated::from_data(LLMEngineOutput::stop());
let maybe_active = active_requests.lock().await.remove(&req_out.request_id); let maybe_active = active_requests.lock().await.remove(&req_out.request_id);
match maybe_active { match maybe_active {
...@@ -744,7 +738,6 @@ async fn output_loop( ...@@ -744,7 +738,6 @@ async fn output_loop(
continue; continue;
} }
let mut remove_after = false;
for vllm_output in req_out.outputs.into_iter() { for vllm_output in req_out.outputs.into_iter() {
let next_total_toks = vllm_output.token_ids.len(); let next_total_toks = vllm_output.token_ids.len();
...@@ -752,23 +745,13 @@ async fn output_loop( ...@@ -752,23 +745,13 @@ async fn output_loop(
Some(active) => { Some(active) => {
let out = from_vllm(vllm_output, active.num_output_tokens_so_far); let out = from_vllm(vllm_output, active.num_output_tokens_so_far);
active.num_output_tokens_so_far = next_total_toks; active.num_output_tokens_so_far = next_total_toks;
let out = if active.num_output_tokens_so_far <= active.max_tokens { let _ = active.tx.send(Annotated::from_data(out)).await;
Annotated::from_data(out)
} else {
// we exceeded max tokens, this request is over
remove_after = true;
Annotated::from_data(LLMEngineOutput::length())
};
let _ = active.tx.send(out).await;
} }
None => { None => {
tracing::warn!(req_out.request_id, "Missing active request"); tracing::warn!(req_out.request_id, "Missing active request");
} }
} }
} }
if remove_after {
let _ = active_requests.lock().await.remove(&req_out.request_id);
}
} }
} }
......
...@@ -143,12 +143,6 @@ impl OpenAIPreprocessor { ...@@ -143,12 +143,6 @@ impl OpenAIPreprocessor {
} }
let mut stop_conditions = request.extract_stop_conditions()?; let mut stop_conditions = request.extract_stop_conditions()?;
// todo - pull this from the mdc default sampling/stop params
if stop_conditions.max_tokens.is_none() {
stop_conditions.max_tokens = Some(64);
}
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden { if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() { for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) { if !stop_tokens.contains(&eos_token) {
......
...@@ -138,15 +138,11 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest { ...@@ -138,15 +138,11 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`, /// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
/// providing access to stop conditions that control chat completion behavior. /// providing access to stop conditions that control chat completion behavior.
#[allow(deprecated)]
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
/// Retrieves the maximum number of tokens allowed in the response. /// Retrieves the maximum number of tokens allowed in the response.
/// #[allow(deprecated)]
/// # Note
/// This field is deprecated in favor of `max_completion_tokens`.
fn get_max_tokens(&self) -> Option<u32> { fn get_max_tokens(&self) -> Option<u32> {
// ALLOW: max_tokens is deprecated in favor of max_completion_tokens self.inner.max_completion_tokens.or(self.inner.max_tokens)
self.inner.max_tokens
} }
/// Retrieves the minimum number of tokens required in the response. /// Retrieves the minimum number of tokens required in the response.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment