Commit 05765cd4 authored by Graham King, committed by GitHub

fix(vllm,sglang): Let the engine enforce max tokens (#216)

Previously, several parts of the stack ensured that max tokens (for a single request) was set.

Now only text input sets it (to 8k). Everything else leaves it as-is, potentially unset. The engines themselves have very small defaults: 16 for vllm and 128 for sglang.
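A minimal sketch of the resulting behavior (the `max_tokens` field name matches this repo's `StopConditions`; everything else here is illustrative):

```rust
// Illustrative only: a simplified StopConditions showing who sets max_tokens now.
#[derive(Default)]
struct StopConditions {
    max_tokens: Option<u32>,
}

fn main() {
    // Text input: dynamo-run itself caps the response at 8k tokens.
    let text_input = StopConditions { max_tokens: Some(8192) };
    // Every other input path: left unset, so the engine's own default
    // applies (16 for vllm, 128 for sglang).
    let passthrough = StopConditions::default();
    assert_eq!(text_input.max_tokens, Some(8192));
    assert_eq!(passthrough.max_tokens, None);
}
```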

Also fix the dynamo-run CUDA startup message so it only prints when the engine in use would benefit from it (mistralrs, llamacpp).
parent 8891aa0c
@@ -24,6 +24,7 @@ use crate::input::common
use crate::EngineConfig;
/// Max response tokens for each single query. Must be less than the model context size.
/// TODO: Add a command-line flag to override this (see the hypothetical sketch after this hunk).
const MAX_TOKENS: u32 = 8192;
pub async fn run(
...
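The TODO in the hunk above is not addressed by this commit. If it were, a clap derive flag could look roughly like this (the `--max-tokens` name and the default are assumptions, not code from this repo):

```rust
// Hypothetical sketch only; this commit does not add the flag.
// Assumes the clap crate with its "derive" feature enabled.
use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Max response tokens for each single query; must be less than the model context size.
    #[arg(long, default_value_t = 8192)]
    max_tokens: u32,
}

fn main() {
    let args = Args::parse();
    println!("max_tokens = {}", args.max_tokens);
}
```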
@@ -108,25 +108,6 @@ fn main() -> anyhow::Result<()> {
}
}
}
#[cfg(any(feature = "mistralrs", feature = "llamacpp"))]
{
#[cfg(feature = "cuda")]
{
tracing::info!("CUDA on");
}
#[cfg(feature = "metal")]
{
tracing::info!("Metal on");
}
#[cfg(feature = "vulkan")]
{
tracing::info!("Vulkan on");
}
#[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
tracing::info!(
"CPU mode. Rebuild with `--features cuda|metal|vulkan` for better performance"
);
}
// max_worker_threads and max_blocking_threads from env vars or config file.
let rt_config = dynamo_runtime::RuntimeConfig::from_settings()?;
@@ -190,6 +171,7 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
default_engine
}
};
print_cuda(&out_opt);
// Clap skips the first argument expecting it to be the binary name, so add it back
// Note `--model-path` has index=1 (in lib.rs) so that doesn't need a flag.
@@ -208,3 +190,39 @@ async fn wrapper(runtime: dynamo_runtime::Runtime) -> anyhow::Result<()> {
)
.await
}
/// If the user will benefit from CUDA/Metal/Vulkan, remind them to build with it.
/// If they have it, celebrate!
// Only mistralrs and llamacpp need to be built with CUDA.
// The Python engines only need it at runtime.
#[cfg(any(feature = "mistralrs", feature = "llamacpp"))]
fn print_cuda(output: &Output) {
// These engines may be compiled in, but are they the chosen one?
match output {
#[cfg(feature = "mistralrs")]
Output::MistralRs => {}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {}
_ => {
return;
}
}
#[cfg(feature = "cuda")]
{
tracing::info!("CUDA on");
}
#[cfg(feature = "metal")]
{
tracing::info!("Metal on");
}
#[cfg(feature = "vulkan")]
{
tracing::info!("Vulkan on");
}
#[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
tracing::info!("CPU mode. Rebuild with `--features cuda|metal|vulkan` for better performance");
}
#[cfg(not(any(feature = "mistralrs", feature = "llamacpp")))]
fn print_cuda(_output: &Output) {}
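This pairing is the usual cfg-stub pattern: the real function when mistralrs or llamacpp is compiled in, and a no-op with the same signature otherwise, so the unconditional call `print_cuda(&out_opt)` above compiles in every feature combination. A standalone sketch with an invented feature name:

```rust
// Minimal sketch of the cfg-stub pattern; the "accel" feature is invented.
#[cfg(feature = "accel")]
fn report_accel() {
    println!("accelerated build");
}

// Same signature, empty body: callers never need cfg attributes of their own.
#[cfg(not(feature = "accel"))]
fn report_accel() {}

fn main() {
    report_accel(); // identical call site either way
}
```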
@@ -142,16 +142,7 @@ impl
request: SingleIn<BackendInput>,
next: ServerStreamingEngine<BackendInput, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<BackendOutput>>> {
// possibly use the request
let mut stop_conditions = request.stop_conditions.clone();
// preprocessor should have set max_tokens
// assert!(stop_conditions.max_tokens.is_some());
if stop_conditions.max_tokens.is_none() {
log::warn!("max_tokens is not set in stop_conditions; fixme");
stop_conditions.max_tokens = Some(256);
}
let stop_conditions = request.stop_conditions.clone();
let next_stream = next.generate(request).await?;
let context = next_stream.context();
@@ -265,9 +256,6 @@ pub struct Decoder {
// do not trigger stop conditions until at least this many tokens have been generated
min_tokens: u32,
// maximum number of tokens to generate - the llm engine should enforce this
max_tokens: u32,
// single tokens that if found in the response will trigger a stop condition after the
// minimum number of tokens have been generated
hidden_stop_ids: HashSet<TokenIdType>,
@@ -372,7 +360,6 @@ impl Decoder {
//visible_stop_ids: HashSet::new(),
//visible_stop_sequences: Vec::new(),
min_tokens: stop_condition.min_tokens.unwrap_or(0),
max_tokens: stop_condition.max_tokens.expect("max_tokens is required"),
generated_tokens: 0,
jail: String::new(),
jail_max_bytes,
@@ -407,14 +394,6 @@ impl Decoder {
));
}
// next check max_tokens limit
if self.generated_tokens >= self.max_tokens {
return Ok(StepResult::with_stop_trigger(
token,
StopTrigger::MaxTokensLimit,
));
}
// check stop sequences - the jail will always hold at least the largest stop sequence
// if jail_max_bytes is 0, then there are no stop sequences
if self.jail_max_bytes > 0 {
...
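With the max-token check deleted, the decoder's remaining per-token stop logic reduces to roughly the following. This is a simplified sketch with assumed field and type names, not the repo's real `Decoder`:

```rust
// Simplified sketch of the remaining stop checks; names approximate the
// Decoder shown above. max_tokens is now the engine's job, not the decoder's.
use std::collections::HashSet;

struct Decoder {
    min_tokens: u32,
    generated_tokens: u32,
    hidden_stop_ids: HashSet<u32>,
}

enum Step {
    Continue,
    Stop,
}

impl Decoder {
    fn step(&mut self, token: u32) -> Step {
        self.generated_tokens += 1;
        // Nothing may stop the stream before min_tokens is reached.
        if self.generated_tokens < self.min_tokens {
            return Step::Continue;
        }
        // A hidden stop token ends the stream without being forwarded.
        if self.hidden_stop_ids.contains(&token) {
            return Step::Stop;
        }
        Step::Continue
    }
}

fn main() {
    let mut d = Decoder {
        min_tokens: 1,
        generated_tokens: 0,
        hidden_stop_ids: HashSet::from([2]), // e.g. an eos token id
    };
    assert!(matches!(d.step(5), Step::Continue));
    assert!(matches!(d.step(2), Step::Stop));
}
```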
@@ -47,9 +47,6 @@ use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason;
use crate::protocols::TokenIdType;
/// If the user does not provide a max_tokens limit, default to this many
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Wait this long for the sglang sub-process to stop after we send it a KILL
const SGLANG_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
@@ -99,7 +96,6 @@ pub struct WorkRequest {
struct ActiveRequest {
tx: Sender<Annotated<LLMEngineOutput>>,
num_output_tokens_so_far: Option<i32>,
max_tokens: i32,
}
/// Python imports
@@ -487,7 +483,7 @@ async fn start_sglang(
tokio::spawn(async move {
let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("{LOG_PREFIX}{tp_rank} {line}");
tracing::debug!("{LOG_PREFIX}{tp_rank} {line}");
}
});
@@ -505,7 +501,7 @@ async fn start_sglang(
}
3 => {
// Normal log line. Skip Python's date/time
tracing::info!("{LOG_PREFIX}{tp_rank} {}", &cap[2]);
tracing::debug!("{LOG_PREFIX}{tp_rank} {}", &cap[2]);
}
x => {
unreachable!("sglang log re only has two capture groups, so {x} entries is impossible");
@@ -608,21 +604,18 @@ async fn input_loop(
.temperature
.unwrap_or(0.0)
.into();
let max_tokens = work_request
.request
.stop_conditions
.max_tokens
.unwrap_or(DEFAULT_MAX_TOKENS);
tracing::trace!("Received work request: {request_id}");
// Parts that don't change
let (py_request_id, sampling_params) = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let mut sp_kwargs = vec![("temperature", py_temp)];
if let Some(max_tokens) = work_request.request.stop_conditions.max_tokens {
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
let sp_kwargs = [("temperature", py_temp), ("max_new_tokens", py_max_tokens)]
.into_py_dict(py)
.unwrap();
// sglang defaults this to 128
sp_kwargs.push(("max_new_tokens", py_max_tokens));
}
let sp_kwargs = sp_kwargs.into_py_dict(py).unwrap();
let sampling_params = py_imports
.sampling_params_type
.call(py, (), Some(&sp_kwargs))
......@@ -666,7 +659,6 @@ async fn input_loop(
});
let new_active_request = ActiveRequest {
tx: work_request.response_channel,
max_tokens: max_tokens as i32,
num_output_tokens_so_far: None,
};
active_requests
@@ -674,7 +666,6 @@ async fn input_loop(
.await
.insert(request_id, new_active_request);
//if let Err(err) = input_socket.send(vec![pickled_req].into()).await {
if let Err(err) = input_socket.send(pickled_req.into()).await {
tracing::error!("Error sending new request to sglang over zmq: {err}");
}
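The conditional kwargs construction above generalizes to both engines (sglang takes `max_new_tokens`, vllm takes `max_tokens`). A minimal standalone sketch, assuming pyo3 0.23-style conversion APIs as used in this diff and a pyo3 build that can initialize Python (e.g. the `auto-initialize` feature):

```rust
// Sketch only: mirrors the vec-then-into_py_dict pattern in the hunk above.
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict};

fn sampling_kwargs<'py>(
    py: Python<'py>,
    temperature: f64,
    max_tokens: Option<u32>,
) -> PyResult<Bound<'py, PyDict>> {
    let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
    let mut kwargs = vec![("temperature", py_temp)];
    // Only set a limit when the request carries one; otherwise the engine's
    // own default applies (16 for vllm, 128 for sglang).
    if let Some(max_tokens) = max_tokens {
        let py_max: PyObject = max_tokens.into_pyobject(py).unwrap().into();
        kwargs.push(("max_new_tokens", py_max));
    }
    kwargs.into_py_dict(py)
}

fn main() -> PyResult<()> {
    Python::with_gil(|py| {
        let kw = sampling_kwargs(py, 0.0, None)?;
        assert!(!kw.contains("max_new_tokens")?); // engine default will apply
        Ok(())
    })
}
```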
@@ -732,6 +723,11 @@ async fn output_loop(
let token_ids: Vec<TokenIdType> = if sglang_finish_reason.is_none() {
req_out.decode_ids[idx][previous_total_toks..].into()
} else {
tracing::trace!(
req_id,
?sglang_finish_reason,
"finished with finish reason"
);
// Request is over, sglang says so.
// The last token is the eos_token, don't forward it
remove_after.push(req_id.clone());
@@ -746,14 +742,7 @@ async fn output_loop(
finish_reason: sglang_finish_reason.map(|x| x.into()),
};
active.num_output_tokens_so_far = Some(next_total_toks);
let out = if next_total_toks <= active.max_tokens {
Annotated::from_data(out)
} else {
// we exceeded max tokens, this request is over
remove_after.push(req_id.clone());
Annotated::from_data(LLMEngineOutput::length())
};
let _ = active.tx.send(out).await;
let _ = active.tx.send(Annotated::from_data(out)).await;
}
None => {
// sglang sends the finish response twice, I don't know why
...
@@ -39,9 +39,6 @@ use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason;
/// If the user does not provide a max_tokens limit, default to this many
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Wait this long for the vllm sub-process to stop after we send it a KILL
const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
@@ -82,7 +79,6 @@ pub struct WorkRequest {
struct ActiveRequest {
tx: Sender<Annotated<LLMEngineOutput>>,
num_output_tokens_so_far: usize,
max_tokens: usize,
}
/// Python imports
@@ -623,19 +619,17 @@ async fn input_loop(
.temperature
.unwrap_or(0.0)
.into();
let max_tokens = work_request
.request
.stop_conditions
.max_tokens
.unwrap_or(DEFAULT_MAX_TOKENS);
// Parts that don't change
let (py_request_id, sampling_params) = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let mut sp_kwargs = vec![("temperature", py_temp)];
if let Some(max_tokens) = work_request.request.stop_conditions.max_tokens {
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
let sp_kwargs = [("temperature", py_temp), ("max_tokens", py_max_tokens)]
.into_py_dict(py)
.unwrap();
// vllm defaults this to 16
sp_kwargs.push(("max_tokens", py_max_tokens));
}
let sp_kwargs = sp_kwargs.into_py_dict(py).unwrap();
let sampling_params = py_imports
.sample_params_type
.call(py, (), Some(&sp_kwargs))
@@ -668,7 +662,6 @@ async fn input_loop(
let new_active_request = ActiveRequest {
tx: work_request.response_channel,
max_tokens: max_tokens as usize,
num_output_tokens_so_far: 0,
};
active_requests
@@ -728,6 +721,7 @@ async fn output_loop(
if req_out.finished {
// The last token is the eos_token, don't forward it
// TODO: Look at req_out.finish_reason (Option<String>) and set out correctly.
let out = Annotated::from_data(LLMEngineOutput::stop());
let maybe_active = active_requests.lock().await.remove(&req_out.request_id);
match maybe_active {
@@ -744,7 +738,6 @@ async fn output_loop(
continue;
}
let mut remove_after = false;
for vllm_output in req_out.outputs.into_iter() {
let next_total_toks = vllm_output.token_ids.len();
@@ -752,23 +745,13 @@ async fn output_loop(
Some(active) => {
let out = from_vllm(vllm_output, active.num_output_tokens_so_far);
active.num_output_tokens_so_far = next_total_toks;
let out = if active.num_output_tokens_so_far <= active.max_tokens {
Annotated::from_data(out)
} else {
// we exceeded max tokens, this request is over
remove_after = true;
Annotated::from_data(LLMEngineOutput::length())
};
let _ = active.tx.send(out).await;
let _ = active.tx.send(Annotated::from_data(out)).await;
}
None => {
tracing::warn!(req_out.request_id, "Missing active request");
}
}
}
if remove_after {
let _ = active_requests.lock().await.remove(&req_out.request_id);
}
}
}
...
@@ -143,12 +143,6 @@ impl OpenAIPreprocessor {
}
let mut stop_conditions = request.extract_stop_conditions()?;
// todo - pull this from the mdc default sampling/stop params
if stop_conditions.max_tokens.is_none() {
stop_conditions.max_tokens = Some(64);
}
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
...
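With the 64-token fallback removed, the surviving part of this preprocessor block only folds the model's eos token ids into the hidden stop list. Roughly, under simplified types (the real code works on `stop_conditions.stop_token_ids_hidden` and `self.model_info.eos_token_ids()`):

```rust
// Simplified sketch of the eos merging kept in the hunk above.
fn add_hidden_eos(stop_tokens: &mut Vec<u32>, eos_token_ids: &[u32]) {
    for &eos in eos_token_ids {
        if !stop_tokens.contains(&eos) {
            stop_tokens.push(eos);
        }
    }
}

fn main() {
    let mut stop = vec![42];
    add_hidden_eos(&mut stop, &[2, 42]);
    assert_eq!(stop, vec![42, 2]); // eos added once, duplicates skipped
}
```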
@@ -138,15 +138,11 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
/// providing access to stop conditions that control chat completion behavior.
#[allow(deprecated)]
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
/// Retrieves the maximum number of tokens allowed in the response.
///
/// # Note
/// This field is deprecated in favor of `max_completion_tokens`.
#[allow(deprecated)]
fn get_max_tokens(&self) -> Option<u32> {
// ALLOW: max_tokens is deprecated in favor of max_completion_tokens
self.inner.max_tokens
self.inner.max_completion_tokens.or(self.inner.max_tokens)
}
/// Retrieves the minimum number of tokens required in the response.
...
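The new body of `get_max_tokens` prefers the non-deprecated field: `Option::or` returns the first `Some`, so `max_completion_tokens` wins whenever both are set. A tiny check of that precedence:

```rust
// Sketch of the precedence in get_max_tokens above.
fn effective_max_tokens(max_completion_tokens: Option<u32>, max_tokens: Option<u32>) -> Option<u32> {
    max_completion_tokens.or(max_tokens)
}

fn main() {
    assert_eq!(effective_max_tokens(Some(100), Some(50)), Some(100)); // new field wins
    assert_eq!(effective_max_tokens(None, Some(50)), Some(50)); // falls back to deprecated field
    assert_eq!(effective_max_tokens(None, None), None); // engine default applies
}
```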