Commit fd95f37b authored by Graham King, committed by GitHub
Browse files

fix(mistralrs): Disable paged attention (#234)

Under load it sometimes drops a request. The request gets added to the batch (sequence) and immediately gets a FinishReason Stop. Not sure why. It doesn't happen with the default scheduler (non-paged attention), so switch to that for now.
parent 48a59890
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
use anyhow::Context as _; use anyhow::Context as _;
use async_openai::types::FinishReason;
use dynamo_llm::model_card::model::ModelDeploymentCard; use dynamo_llm::model_card::model::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::types::openai::chat_completions::{ use dynamo_llm::types::openai::chat_completions::{
...@@ -37,7 +38,6 @@ use crate::EngineConfig; ...@@ -37,7 +38,6 @@ use crate::EngineConfig;
const MAX_TOKENS: u32 = 8192; const MAX_TOKENS: u32 = 8192;
const OUTPUT_FILENAME: &str = "output.jsonl"; const OUTPUT_FILENAME: &str = "output.jsonl";
const DUMMY_MODEL_NAME: &str = "dynamo-run-batch";
#[derive(Serialize, Deserialize, Default, Debug)] #[derive(Serialize, Deserialize, Default, Debug)]
struct Entry { struct Entry {
...@@ -54,6 +54,12 @@ struct Entry { ...@@ -54,6 +54,12 @@ struct Entry {
#[serde(default)] #[serde(default)]
elapsed_ms: usize, elapsed_ms: usize,
#[serde(default, skip_serializing_if = "Option::is_none")]
finish_reason: Option<FinishReason>,
#[serde(skip, default)]
request_id: usize,
} }
pub async fn run( pub async fn run(
...@@ -71,29 +77,21 @@ pub async fn run( ...@@ -71,29 +77,21 @@ pub async fn run(
); );
} }
let (_service_name, engine, _inspect_template) = let (service_name, engine, _inspect_template) =
common::prepare_engine(runtime.clone(), engine_config).await?; common::prepare_engine(runtime.clone(), engine_config).await?;
let service_name_ref = Arc::new(service_name);
let pre_processor = if let Some(card) = maybe_card { let pre_processor = if let Some(card) = maybe_card {
Some(OpenAIPreprocessor::new(card).await?) Some(OpenAIPreprocessor::new(card).await?)
} else { } else {
None None
}; };
let (all_finish_tx, all_finish_rx) = tokio::sync::oneshot::channel();
let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64); let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
let dw_cancel_token = cancel_token.clone(); let dw_cancel_token = cancel_token.clone();
let mut output_file = input_jsonl.clone(); let mut output_file = input_jsonl.clone();
output_file.set_file_name(OUTPUT_FILENAME); output_file.set_file_name(OUTPUT_FILENAME);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = output_writer( if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
dw_cancel_token,
done_entries_rx,
&output_file,
all_finish_tx,
)
.await
{
tracing::error!(%err, "Failed writing output to {}", output_file.display()); tracing::error!(%err, "Failed writing output to {}", output_file.display());
} }
}); });
...@@ -125,15 +123,18 @@ pub async fn run( ...@@ -125,15 +123,18 @@ pub async fn run(
anyhow::bail!("Error parsing entry: '{line}'. {err}"); anyhow::bail!("Error parsing entry: '{line}'. {err}");
} }
}; };
entry.request_id = request_id;
let engine = engine.clone(); let engine = engine.clone();
let pre_processor = pre_processor.clone(); let pre_processor = pre_processor.clone();
let tokens_in = tokens_in.clone(); let tokens_in = tokens_in.clone();
let tokens_out = tokens_out.clone(); let tokens_out = tokens_out.clone();
let done_entries_tx = done_entries_tx.clone(); let done_entries_tx = done_entries_tx.clone();
let service_name_ref = service_name_ref.clone();
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
let local_start = Instant::now(); let local_start = Instant::now();
let response = match evaluate(request_id, engine, &entry.text).await { let response =
match evaluate(request_id, service_name_ref.as_str(), engine, &mut entry).await {
Ok(r) => r, Ok(r) => r,
Err(err) => { Err(err) => {
tracing::error!(%err, entry.text, "Failed evaluating prompt"); tracing::error!(%err, entry.text, "Failed evaluating prompt");
...@@ -175,8 +176,6 @@ pub async fn run( ...@@ -175,8 +176,6 @@ pub async fn run(
} }
_ = futures::future::join_all(handles) => { _ = futures::future::join_all(handles) => {
} }
_ = all_finish_rx => {
}
} }
let elapsed = Instant::now() - start; let elapsed = Instant::now() - start;
let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64); let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
...@@ -198,23 +197,24 @@ pub async fn run( ...@@ -198,23 +197,24 @@ pub async fn run(
// Run a single prompt through the engine // Run a single prompt through the engine
async fn evaluate( async fn evaluate(
_request_id: usize, request_id: usize,
service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
prompt: &str, entry: &mut Entry,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let user_message = async_openai::types::ChatCompletionRequestMessage::User( let user_message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage { async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text( content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
prompt.to_string(), entry.text.clone(),
), ),
name: None, name: None,
}, },
); );
let inner = async_openai::types::CreateChatCompletionRequestArgs::default() let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(vec![user_message]) .messages(vec![user_message])
.model(DUMMY_MODEL_NAME) .model(service_name)
.stream(true) .stream(true)
.max_tokens(MAX_TOKENS) .max_completion_tokens(MAX_TOKENS)
.build()?; .build()?;
let req = NvCreateChatCompletionRequest { inner, nvext: None }; let req = NvCreateChatCompletionRequest { inner, nvext: None };
let mut stream = engine.generate(Context::new(req)).await?; let mut stream = engine.generate(Context::new(req)).await?;
...@@ -223,24 +223,30 @@ async fn evaluate( ...@@ -223,24 +223,30 @@ async fn evaluate(
match (item.data.as_ref(), item.event.as_deref()) { match (item.data.as_ref(), item.event.as_deref()) {
(Some(data), _) => { (Some(data), _) => {
// Normal case // Normal case
let entry = data.inner.choices.first(); let choice = data.inner.choices.first();
let chat_comp = entry.as_ref().unwrap(); let chat_comp = choice.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content { if let Some(c) = &chat_comp.delta.content {
output += c; output += c;
} }
entry.finish_reason = chat_comp.finish_reason;
if chat_comp.finish_reason.is_some() { if chat_comp.finish_reason.is_some() {
tracing::trace!("finish reason: {:?}", chat_comp.finish_reason.unwrap()); tracing::trace!(
request_id,
"finish reason: {:?}",
chat_comp.finish_reason.unwrap()
);
break; break;
} }
} }
(None, Some("error")) => { (None, Some("error")) => {
tracing::error!(request_id, "the error case");
// There's only one error but we loop in case that changes // There's only one error but we loop in case that changes
for err in item.comment.unwrap_or_default() { for err in item.comment.unwrap_or_default() {
tracing::error!("Engine error: {err}"); tracing::error!(request_id, "Engine error: {err}");
} }
} }
(None, Some(annotation)) => { (None, Some(annotation)) => {
tracing::debug!("Annotation. {annotation}: {:?}", item.comment); tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
} }
_ => { _ => {
unreachable!("Event from engine with no data, no error, no annotation."); unreachable!("Event from engine with no data, no error, no annotation.");
...@@ -254,22 +260,20 @@ async fn output_writer( ...@@ -254,22 +260,20 @@ async fn output_writer(
cancel_token: CancellationToken, cancel_token: CancellationToken,
mut entries_rx: tokio::sync::mpsc::Receiver<Entry>, mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
output_file: &Path, output_file: &Path,
all_finish_tx: tokio::sync::oneshot::Sender<()>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut num_completed = 0; let mut num_completed = 0;
let mut f = tokio::fs::File::create(output_file).await?; let mut f = tokio::fs::File::create(output_file).await?;
loop { loop {
let maybe_entry = tokio::select! { let entry = tokio::select! {
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
break; break;
} }
entry = entries_rx.recv() => { maybe_entry = entries_rx.recv() => {
entry match maybe_entry {
Some(entry) => entry,
None => {break;}
}
} }
};
let Some(entry) = maybe_entry else {
let _ = all_finish_tx.send(());
break;
}; };
let mut s = serde_json::to_string(&entry)?; let mut s = serde_json::to_string(&entry)?;
s.push('\n'); s.push('\n');
...@@ -278,7 +282,7 @@ async fn output_writer( ...@@ -278,7 +282,7 @@ async fn output_writer(
num_completed += 1; num_completed += 1;
// TODO: Progress bar. We'd have to count the lines in the input first, // TODO: Progress bar. We'd have to count the lines in the input first,
// and the input maybe be large // and the input maybe be large
tracing::info!("Saved {num_completed}"); tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
} }
Ok(()) Ok(())
} }
...@@ -39,7 +39,15 @@ use crate::protocols::openai::chat_completions::{ ...@@ -39,7 +39,15 @@ use crate::protocols::openai::chat_completions::{
}; };
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine; use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5; /// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
/// I would call this the batch size but apparently that's something else.
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;
/// Experimental: Switch this to true to enable paged attention on CUDA devices.
/// Under load (dynamo-run batch mode) paged attention sometimes returns an immediate
/// finish_reason=stop and no tokens for one of the requests.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;
pub async fn make_engine( pub async fn make_engine(
gguf_path: &Path, gguf_path: &Path,
...@@ -110,10 +118,10 @@ impl MistralRsEngine { ...@@ -110,10 +118,10 @@ impl MistralRsEngine {
let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN; let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
// Paged attention requires cuda // Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") { let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
Some(PagedAttentionConfig::new( Some(PagedAttentionConfig::new(
None, // Block size, default 32 None, // Block size, default 32
512, // CPU memory in MiB 4096, // CPU memory in MiB
MemoryGpuConfig::ContextSize(max_seq_len), MemoryGpuConfig::ContextSize(max_seq_len),
)?) )?)
} else { } else {
...@@ -133,7 +141,7 @@ impl MistralRsEngine { ...@@ -133,7 +141,7 @@ impl MistralRsEngine {
None, None,
paged_attention_config, paged_attention_config,
)?; )?;
let scheduler = if cfg!(feature = "cuda") { let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
tracing::debug!("Using mistralrs PagedAttentionMeta scheduler"); tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() { let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
Some(conf) => conf.clone(), Some(conf) => conf.clone(),
...@@ -154,9 +162,12 @@ impl MistralRsEngine { ...@@ -154,9 +162,12 @@ impl MistralRsEngine {
}; };
// Create the MistralRs, which is a runner // Create the MistralRs, which is a runner
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16); let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
Ok(MistralRsEngine { let engine = MistralRsEngine {
mistralrs: builder.build(), mistralrs: builder.build(),
}) };
// skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
let _ = engine.mistralrs.next_request_id();
Ok(engine)
} }
} }
...@@ -231,13 +242,14 @@ impl ...@@ -231,13 +242,14 @@ impl
n_choices: 1, n_choices: 1,
dry_params: det.dry_params, dry_params: det.dry_params,
}; };
let request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(NormalRequest { let mistralrs_request = Request::Normal(NormalRequest {
id: request_id,
messages: RequestMessage::Chat(messages), messages: RequestMessage::Chat(messages),
sampling_params, sampling_params,
response: tx, response: tx,
return_logprobs: request.inner.logprobs.unwrap_or_default(), return_logprobs: request.inner.logprobs.unwrap_or_default(),
is_streaming: true, is_streaming: true,
id: self.mistralrs.next_request_id(),
constraint: Constraint::None, constraint: Constraint::None,
suffix: None, suffix: None,
adapters: None, adapters: None,
...@@ -254,14 +266,14 @@ impl ...@@ -254,14 +266,14 @@ impl
let response = match response.as_result() { let response = match response.as_result() {
Ok(r) => r, Ok(r) => r,
Err(err) => { Err(err) => {
tracing::error!(%err, "Failed converting mistralrs channel response to result."); tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
break; break;
} }
}; };
match response { match response {
ResponseOk::Chunk(c) => { ResponseOk::Chunk(c) => {
let Some(from_assistant) = c.choices[0].delta.content.clone() else { let Some(from_assistant) = c.choices[0].delta.content.clone() else {
tracing::warn!("No content from mistralrs. Abandoning request."); tracing::warn!(request_id, "No content from mistralrs. Abandoning request.");
break; break;
}; };
let finish_reason = match &c.choices[0].finish_reason.as_deref() { let finish_reason = match &c.choices[0].finish_reason.as_deref() {
...@@ -272,7 +284,7 @@ impl ...@@ -272,7 +284,7 @@ impl
Some(FinishReason::Length) Some(FinishReason::Length)
} }
Some(s) => { Some(s) => {
tracing::warn!(stop_reason = s, "Unknow stop reason"); tracing::warn!(request_id, stop_reason = s, "Unknow stop reason");
Some(FinishReason::Stop) Some(FinishReason::Stop)
} }
None => None, None => None,
...@@ -312,11 +324,11 @@ impl ...@@ -312,11 +324,11 @@ impl
yield ann; yield ann;
if finish_reason.is_some() { if finish_reason.is_some() {
//tracing::trace!("Finish reason: {finish_reason:?}"); //tracing::trace!(request_id, "Finish reason: {finish_reason:?}");
break; break;
} }
}, },
x => tracing::error!("Unhandled. {x:?}"), x => tracing::error!(request_id, "Unhandled. {x:?}"),
} }
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment