"docs/guides/vscode:/vscode.git/clone" did not exist on "1c77531a501dee670d4b1d71e198b857f451ca55"
Commit fd95f37b authored by Graham King, committed by GitHub

fix(mistralrs): Disable paged attention (#234)

Under load, paged attention sometimes drops a request: the request is added to the batch (sequence) and immediately returns FinishReason::Stop with no tokens. The cause is not yet clear. This does not happen with the default (non-paged-attention) scheduler, so switch to that for now.
parent 48a59890
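To illustrate the core of the change before the diff: paged attention is now gated behind an experimental constant as well as the CUDA feature, so the non-paged default scheduler is used for now. A minimal, hypothetical sketch of that gating follows; only `EXP_ENABLE_PAGED_ATTENTION` and the value 10 for max sequences come from the diff, the other names are illustrative stand-ins, not the mistralrs API.

```rust
/// Mirrors the experimental toggle added in this commit; false disables paged attention.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;

/// Illustrative stand-in for the two scheduler configurations (not mistralrs types).
#[derive(Debug)]
enum SchedulerChoice {
    /// Paged-attention scheduler: under load it sometimes returned an immediate
    /// finish_reason=stop with no tokens for one of the requests.
    PagedAttention { max_num_seqs: usize },
    /// Default (non-paged) scheduler: does not show the dropped-request behaviour.
    Default,
}

fn pick_scheduler(cuda_available: bool) -> SchedulerChoice {
    // Paged attention requires CUDA *and* the experimental flag.
    if cuda_available && EXP_ENABLE_PAGED_ATTENTION {
        SchedulerChoice::PagedAttention { max_num_seqs: 10 }
    } else {
        SchedulerChoice::Default
    }
}

fn main() {
    // With the flag hard-coded to false, even a CUDA build takes the default path.
    println!("{:?}", pick_scheduler(true));
    println!("{:?}", pick_scheduler(false));
}
```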
@@ -14,6 +14,7 @@
// limitations under the License.
use anyhow::Context as _;
use async_openai::types::FinishReason;
use dynamo_llm::model_card::model::ModelDeploymentCard;
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::types::openai::chat_completions::{
@@ -37,7 +38,6 @@ use crate::EngineConfig;
const MAX_TOKENS: u32 = 8192;
const OUTPUT_FILENAME: &str = "output.jsonl";
const DUMMY_MODEL_NAME: &str = "dynamo-run-batch";
#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
@@ -54,6 +54,12 @@ struct Entry {
#[serde(default)]
elapsed_ms: usize,
#[serde(default, skip_serializing_if = "Option::is_none")]
finish_reason: Option<FinishReason>,
#[serde(skip, default)]
request_id: usize,
}
pub async fn run(
@@ -71,29 +77,21 @@ pub async fn run(
);
}
let (_service_name, engine, _inspect_template) =
let (service_name, engine, _inspect_template) =
common::prepare_engine(runtime.clone(), engine_config).await?;
let service_name_ref = Arc::new(service_name);
let pre_processor = if let Some(card) = maybe_card {
Some(OpenAIPreprocessor::new(card).await?)
} else {
None
};
let (all_finish_tx, all_finish_rx) = tokio::sync::oneshot::channel();
let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
let dw_cancel_token = cancel_token.clone();
let mut output_file = input_jsonl.clone();
output_file.set_file_name(OUTPUT_FILENAME);
tokio::spawn(async move {
if let Err(err) = output_writer(
dw_cancel_token,
done_entries_rx,
&output_file,
all_finish_tx,
)
.await
{
if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
tracing::error!(%err, "Failed writing output to {}", output_file.display());
}
});
@@ -125,15 +123,18 @@ pub async fn run(
anyhow::bail!("Error parsing entry: '{line}'. {err}");
}
};
entry.request_id = request_id;
let engine = engine.clone();
let pre_processor = pre_processor.clone();
let tokens_in = tokens_in.clone();
let tokens_out = tokens_out.clone();
let done_entries_tx = done_entries_tx.clone();
let service_name_ref = service_name_ref.clone();
let handle = tokio::spawn(async move {
let local_start = Instant::now();
let response = match evaluate(request_id, engine, &entry.text).await {
let response =
match evaluate(request_id, service_name_ref.as_str(), engine, &mut entry).await {
Ok(r) => r,
Err(err) => {
tracing::error!(%err, entry.text, "Failed evaluating prompt");
@@ -175,8 +176,6 @@
}
_ = futures::future::join_all(handles) => {
}
_ = all_finish_rx => {
}
}
let elapsed = Instant::now() - start;
let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
@@ -198,23 +197,24 @@ pub async fn run(
// Run a single prompt through the engine
async fn evaluate(
_request_id: usize,
request_id: usize,
service_name: &str,
engine: OpenAIChatCompletionsStreamingEngine,
prompt: &str,
entry: &mut Entry,
) -> anyhow::Result<String> {
let user_message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
prompt.to_string(),
entry.text.clone(),
),
name: None,
},
);
let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
.messages(vec![user_message])
.model(DUMMY_MODEL_NAME)
.model(service_name)
.stream(true)
.max_tokens(MAX_TOKENS)
.max_completion_tokens(MAX_TOKENS)
.build()?;
let req = NvCreateChatCompletionRequest { inner, nvext: None };
let mut stream = engine.generate(Context::new(req)).await?;
@@ -223,24 +223,30 @@ async fn evaluate(
match (item.data.as_ref(), item.event.as_deref()) {
(Some(data), _) => {
// Normal case
let entry = data.inner.choices.first();
let chat_comp = entry.as_ref().unwrap();
let choice = data.inner.choices.first();
let chat_comp = choice.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content {
output += c;
}
entry.finish_reason = chat_comp.finish_reason;
if chat_comp.finish_reason.is_some() {
tracing::trace!("finish reason: {:?}", chat_comp.finish_reason.unwrap());
tracing::trace!(
request_id,
"finish reason: {:?}",
chat_comp.finish_reason.unwrap()
);
break;
}
}
(None, Some("error")) => {
tracing::error!(request_id, "the error case");
// There's only one error but we loop in case that changes
for err in item.comment.unwrap_or_default() {
tracing::error!("Engine error: {err}");
tracing::error!(request_id, "Engine error: {err}");
}
}
(None, Some(annotation)) => {
tracing::debug!("Annotation. {annotation}: {:?}", item.comment);
tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
}
_ => {
unreachable!("Event from engine with no data, no error, no annotation.");
@@ -254,22 +260,20 @@ async fn output_writer(
cancel_token: CancellationToken,
mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
output_file: &Path,
all_finish_tx: tokio::sync::oneshot::Sender<()>,
) -> anyhow::Result<()> {
let mut num_completed = 0;
let mut f = tokio::fs::File::create(output_file).await?;
loop {
let maybe_entry = tokio::select! {
let entry = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
entry = entries_rx.recv() => {
entry
maybe_entry = entries_rx.recv() => {
match maybe_entry {
Some(entry) => entry,
None => {break;}
}
}
};
let Some(entry) = maybe_entry else {
let _ = all_finish_tx.send(());
break;
};
let mut s = serde_json::to_string(&entry)?;
s.push('\n');
@@ -278,7 +282,7 @@
num_completed += 1;
// TODO: Progress bar. We'd have to count the lines in the input first,
// and the input may be large
tracing::info!("Saved {num_completed}");
tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
}
Ok(())
}
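A note on the shutdown change in `output_writer` above: the `all_finish_tx` oneshot is no longer needed because `recv()` on a Tokio mpsc channel resolves to `None` once every `Sender` clone has been dropped, so the writer exits naturally when the last request task finishes. A small, self-contained sketch of that pattern (assumes a Tokio runtime; `Entry` here is a simplified stand-in, not the struct from the diff):

```rust
// Sketch of the channel-close shutdown pattern, assuming tokio with the
// "macros", "rt-multi-thread" and "sync" features enabled.

#[derive(Debug)]
struct Entry {
    request_id: usize,
    text: String,
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Entry>(64);

    // Writer task: no oneshot needed, recv() returns None when all senders drop.
    let writer = tokio::spawn(async move {
        let mut num_completed = 0usize;
        while let Some(entry) = rx.recv().await {
            num_completed += 1;
            println!("Saved {num_completed}: {entry:?}");
        }
        num_completed
    });

    // Each per-request task holds its own clone of the sender.
    let mut handles = Vec::new();
    for request_id in 0..3 {
        let tx = tx.clone();
        handles.push(tokio::spawn(async move {
            let _ = tx
                .send(Entry {
                    request_id,
                    text: format!("prompt {request_id}"),
                })
                .await;
        }));
    }
    drop(tx); // drop the original sender; only the per-task clones remain

    for handle in handles {
        handle.await.unwrap();
    }
    // All senders are gone, so the writer loop ends and reports its count.
    assert_eq!(writer.await.unwrap(), 3);
}
```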
@@ -39,7 +39,15 @@ use crate::protocols::openai::chat_completions::{
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5;
/// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
/// I would call this the batch size but apparently that's something else.
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;
/// Experimental: Switch this to true to enable paged attention on CUDA devices.
/// Under load (dynamo-run batch mode) paged attention sometimes returns an immediate
/// finish_reason=stop and no tokens for one of the requests.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;
pub async fn make_engine(
gguf_path: &Path,
@@ -110,10 +118,10 @@ impl MistralRsEngine {
let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
// Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") {
let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
Some(PagedAttentionConfig::new(
None, // Block size, default 32
512, // CPU memory in MiB
4096, // CPU memory in MiB
MemoryGpuConfig::ContextSize(max_seq_len),
)?)
} else {
@@ -133,7 +141,7 @@ impl MistralRsEngine {
None,
paged_attention_config,
)?;
let scheduler = if cfg!(feature = "cuda") {
let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
Some(conf) => conf.clone(),
@@ -154,9 +162,12 @@ impl MistralRsEngine {
};
// Create the MistralRs, which is a runner
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
Ok(MistralRsEngine {
let engine = MistralRsEngine {
mistralrs: builder.build(),
})
};
// skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
let _ = engine.mistralrs.next_request_id();
Ok(engine)
}
}
@@ -231,13 +242,14 @@ impl
n_choices: 1,
dry_params: det.dry_params,
};
let request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(NormalRequest {
id: request_id,
messages: RequestMessage::Chat(messages),
sampling_params,
response: tx,
return_logprobs: request.inner.logprobs.unwrap_or_default(),
is_streaming: true,
id: self.mistralrs.next_request_id(),
constraint: Constraint::None,
suffix: None,
adapters: None,
@@ -254,14 +266,14 @@ impl
let response = match response.as_result() {
Ok(r) => r,
Err(err) => {
tracing::error!(%err, "Failed converting mistralrs channel response to result.");
tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
break;
}
};
match response {
ResponseOk::Chunk(c) => {
let Some(from_assistant) = c.choices[0].delta.content.clone() else {
tracing::warn!("No content from mistralrs. Abandoning request.");
tracing::warn!(request_id, "No content from mistralrs. Abandoning request.");
break;
};
let finish_reason = match &c.choices[0].finish_reason.as_deref() {
@@ -272,7 +284,7 @@ impl
Some(FinishReason::Length)
}
Some(s) => {
tracing::warn!(stop_reason = s, "Unknown stop reason");
tracing::warn!(request_id, stop_reason = s, "Unknown stop reason");
Some(FinishReason::Stop)
}
None => None,
@@ -312,11 +324,11 @@ impl
yield ann;
if finish_reason.is_some() {
//tracing::trace!("Finish reason: {finish_reason:?}");
//tracing::trace!(request_id, "Finish reason: {finish_reason:?}");
break;
}
},
x => tracing::error!("Unhandled. {x:?}"),
x => tracing::error!(request_id, "Unhandled. {x:?}"),
}
}
};
......
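One more detail that recurs throughout the diff: `request_id` is passed to the `tracing` macros as a bare identifier, which `tracing` records as a structured key-value field rather than interpolating it into the message text. A tiny sketch of that shorthand (assumes the `tracing` and `tracing-subscriber` crates; the subscriber setup and message are illustrative):

```rust
fn main() {
    // Install a default fmt subscriber so the event below is actually printed.
    tracing_subscriber::fmt::init();

    let request_id: usize = 42;
    let err = "connection reset";

    // A bare identifier becomes a field (request_id=42);
    // `%err` records the value using its Display implementation.
    tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
}
```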