Unverified Commit f69580b0 authored by Oleg Zhelezniak's avatar Oleg Zhelezniak Committed by GitHub
Browse files

fix: tool call loss under speculative decoding (#7768)


Signed-off-by: default avatarjellysnack <oleg.jellysnack@gmail.com>
parent 1c199f88
...@@ -232,6 +232,10 @@ impl ChoiceJailState { ...@@ -232,6 +232,10 @@ impl ChoiceJailState {
// Handle trailing content if any // Handle trailing content if any
if !trailing_part.is_empty() { if !trailing_part.is_empty() {
if jail_stream.should_start_jail(trailing_part) {
self.is_jailed = true;
self.accumulated_content = trailing_part.to_string();
} else {
#[allow(deprecated)] #[allow(deprecated)]
let trailing_choice = create_choice_stream( let trailing_choice = create_choice_stream(
choice.index, choice.index,
...@@ -244,6 +248,7 @@ impl ChoiceJailState { ...@@ -244,6 +248,7 @@ impl ChoiceJailState {
); );
emissions.push(ChoiceEmission::Trailing(trailing_choice)); emissions.push(ChoiceEmission::Trailing(trailing_choice));
} }
}
} else { } else {
// Start jailing with the marker and suffix // Start jailing with the marker and suffix
self.is_jailed = true; self.is_jailed = true;
...@@ -347,13 +352,21 @@ impl ChoiceJailState { ...@@ -347,13 +352,21 @@ impl ChoiceJailState {
emissions.push(ChoiceEmission::Content(unjailed_choice)); emissions.push(ChoiceEmission::Content(unjailed_choice));
} }
// End jailing before processing trailing content
let trailing_owned = trailing_part.to_string();
self.end_jail();
// Handle trailing content if any // Handle trailing content if any
if !trailing_part.is_empty() { if !trailing_owned.is_empty() {
if jail_stream.should_start_jail(&trailing_owned) {
self.is_jailed = true;
self.accumulated_content = trailing_owned;
} else {
#[allow(deprecated)] #[allow(deprecated)]
let trailing_choice = create_choice_stream( let trailing_choice = create_choice_stream(
choice.index, choice.index,
choice.delta.role, choice.delta.role,
trailing_part, &trailing_owned,
None, None,
choice.finish_reason, choice.finish_reason,
None, None,
...@@ -361,9 +374,7 @@ impl ChoiceJailState { ...@@ -361,9 +374,7 @@ impl ChoiceJailState {
); );
emissions.push(ChoiceEmission::Trailing(trailing_choice)); emissions.push(ChoiceEmission::Trailing(trailing_choice));
} }
}
// End jailing
self.end_jail();
} }
// If not unjailing, don't emit anything (still accumulating) // If not unjailing, don't emit anything (still accumulating)
} }
...@@ -1311,3 +1322,197 @@ impl Default for JailedStreamBuilder { ...@@ -1311,3 +1322,197 @@ impl Default for JailedStreamBuilder {
Self::new() Self::new()
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use dynamo_protocols::types::CreateChatCompletionStreamResponse;
use futures::stream;
/// Helper: build a single-choice stream chunk with text content
#[allow(deprecated)]
fn text_chunk(text: &str) -> Annotated<NvCreateChatCompletionStreamResponse> {
let choice = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(dynamo_protocols::types::ChatCompletionMessageContent::Text(
text.to_string(),
)),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
};
Annotated {
data: Some(NvCreateChatCompletionStreamResponse {
inner: CreateChatCompletionStreamResponse {
id: "id-42".to_string(),
object: "chat.completion.chunk".to_string(),
created: 0,
model: "test-model".to_string(),
choices: vec![choice],
usage: None,
service_tier: None,
system_fingerprint: None,
},
nvext: None,
}),
id: None,
event: None,
comment: None,
error: None,
}
}
/// Collect all emitted tool calls from the jailed stream output
fn collect_tool_calls(
responses: &[Annotated<NvCreateChatCompletionStreamResponse>],
) -> Vec<(String, String)> {
let mut tool_calls = Vec::new();
for resp in responses {
if let Some(ref data) = resp.data {
for choice in &data.inner.choices {
if let Some(ref tcs) = choice.delta.tool_calls {
for tc in tcs {
if let Some(ref func) = tc.function {
let name = func.name.clone().unwrap_or_default();
let args = func.arguments.clone().unwrap_or_default();
tool_calls.push((name, args));
}
}
}
}
}
}
tool_calls
}
/// Collect all emitted text content from the jailed stream output
fn collect_text_content(
responses: &[Annotated<NvCreateChatCompletionStreamResponse>],
) -> String {
responses
.iter()
.flat_map(|r| r.data.iter())
.flat_map(|d| d.inner.choices.iter())
.filter_map(|c| {
if let Some(dynamo_protocols::types::ChatCompletionMessageContent::Text(t)) =
&c.delta.content
{
Some(t.as_str())
} else {
None
}
})
.collect()
}
#[tokio::test]
async fn test_multi_tool_call_single_chunk() {
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let chunks = vec![text_chunk(
"<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"SF\"}}\n</tool_call>\n<tool_call>\n{\"name\": \"get_time\", \"arguments\": {\"timezone\": \"PST\"}}\n</tool_call>",
)];
let input_stream = Box::pin(stream::iter(chunks));
let output_stream = jail.apply_with_finish_reason(input_stream);
let responses: Vec<_> = output_stream.collect().await;
let tool_calls = collect_tool_calls(&responses);
assert!(
tool_calls.len() >= 2,
"Expected at least 2 tool calls, got {}: {:?}",
tool_calls.len(),
tool_calls
);
let names: Vec<&str> = tool_calls.iter().map(|(n, _)| n.as_str()).collect();
assert!(
names.contains(&"get_weather"),
"Missing get_weather tool call. Got: {:?}",
names
);
assert!(
names.contains(&"get_time"),
"Missing get_time tool call. Got: {:?}",
names
);
}
#[tokio::test]
async fn test_multi_tool_call_multiple_chunks() {
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let chunks = vec![
text_chunk("<tool_call>\n{\"name\": \"get_weather\", \"arguments\""),
text_chunk(
": {\"location\": \"SF\"}}\n</tool_call>\n<tool_call>\n{\"name\": \"get_time\"",
),
text_chunk(", \"arguments\": {\"timezone\": \"PST\"}}\n</tool_call>"),
];
let input_stream = Box::pin(stream::iter(chunks));
let output_stream = jail.apply_with_finish_reason(input_stream);
let responses: Vec<_> = output_stream.collect().await;
let tool_calls = collect_tool_calls(&responses);
assert!(
tool_calls.len() >= 2,
"Expected at least 2 tool calls, got {}: {:?}",
tool_calls.len(),
tool_calls
);
let names: Vec<&str> = tool_calls.iter().map(|(n, _)| n.as_str()).collect();
assert!(
names.contains(&"get_weather"),
"Missing get_weather tool call. Got: {:?}",
names
);
assert!(
names.contains(&"get_time"),
"Missing get_time tool call. Got: {:?}",
names
);
}
#[tokio::test]
async fn test_trailing_text_not_re_jailed() {
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let chunks = vec![text_chunk(
"<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"SF\"}}\n</tool_call>\nDone!",
)];
let input_stream = Box::pin(stream::iter(chunks));
let output_stream = jail.apply_with_finish_reason(input_stream);
let responses: Vec<_> = output_stream.collect().await;
let tool_calls = collect_tool_calls(&responses);
assert_eq!(
tool_calls.len(),
1,
"Expected exactly 1 tool call, got {}: {:?}",
tool_calls.len(),
tool_calls
);
assert_eq!(tool_calls[0].0, "get_weather");
let all_text = collect_text_content(&responses);
assert!(
all_text.contains("Done!"),
"Trailing text 'Done!' should appear in output. Got text: {:?}",
all_text
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment