Unverified Commit 3fd0ab3d authored by ryan-lempka's avatar ryan-lempka Committed by GitHub
Browse files

fix: multi-turn bug in should_add_generation_prompt (#4168)

parent 7750ed1a
...@@ -163,14 +163,17 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -163,14 +163,17 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
} }
fn should_add_generation_prompt(&self) -> bool { fn should_add_generation_prompt(&self) -> bool {
if let Some(last) = self.inner.messages.last() { // Only add generation prompt if the last message was not assistant (default to true when no last message)
matches!( self.inner
last, .messages
dynamo_async_openai::types::ChatCompletionRequestMessage::User(_) .last()
) .map(|last| {
} else { !matches!(
true last,
} dynamo_async_openai::types::ChatCompletionRequestMessage::Assistant(_)
)
})
.unwrap_or(true)
} }
fn extract_text(&self) -> Option<TextInput> { fn extract_text(&self) -> Option<TextInput> {
...@@ -294,6 +297,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { ...@@ -294,6 +297,7 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use dynamo_async_openai::types::ChatCompletionRequestMessage as Msg;
#[test] #[test]
fn test_may_be_fix_tool_schema_missing_type_and_properties() { fn test_may_be_fix_tool_schema_missing_type_and_properties() {
...@@ -700,4 +704,46 @@ NORMAL MODE ...@@ -700,4 +704,46 @@ NORMAL MODE
assert!(messages[0]["content"].is_array()); assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3); assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
} }
fn user() -> Msg {
Msg::User(Default::default())
}
fn asst() -> Msg {
Msg::Assistant(Default::default())
}
fn tool() -> Msg {
Msg::Tool(Default::default())
}
fn dummy_state(messages: Vec<Msg>) -> NvCreateChatCompletionRequest {
let json = serde_json::json!({
"model": "test-model",
"messages": messages
});
serde_json::from_value(json).unwrap()
}
#[test]
fn add_after_user() {
let s = dummy_state(vec![user()]);
assert!(s.should_add_generation_prompt());
}
#[test]
fn add_after_tool() {
let s = dummy_state(vec![tool()]);
assert!(s.should_add_generation_prompt());
}
#[test]
fn no_after_assistant() {
let s = dummy_state(vec![asst()]);
assert!(!s.should_add_generation_prompt());
}
#[test]
fn add_when_empty() {
let s = dummy_state(vec![]);
assert!(s.should_add_generation_prompt());
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment