Unverified Commit 59d18d8e authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

fix(jail): preserve logprobs through tool-call jailing (#8072)


Signed-off-by: default avatarjthomson04 <jothomson@nvidia.com>
Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent b0f7d8a5
...@@ -100,6 +100,10 @@ struct ChoiceJailState { ...@@ -100,6 +100,10 @@ struct ChoiceJailState {
is_jailed: bool, is_jailed: bool,
/// Accumulated content for this choice while jailed /// Accumulated content for this choice while jailed
accumulated_content: String, accumulated_content: String,
/// Accumulated logprobs for this choice while jailed.
/// Logprobs from each jailed chunk are appended so the full token-level
/// log-probability information is preserved when the jail emits.
accumulated_logprobs: Option<ChatChoiceLogprobs>,
/// Buffer for partial marker matches across chunks /// Buffer for partial marker matches across chunks
partial_match_buffer: String, partial_match_buffer: String,
/// Stream finish reason /// Stream finish reason
...@@ -145,6 +149,7 @@ impl ChoiceJailState { ...@@ -145,6 +149,7 @@ impl ChoiceJailState {
index, index,
is_jailed: starts_jailed, is_jailed: starts_jailed,
accumulated_content: String::new(), accumulated_content: String::new(),
accumulated_logprobs: None,
partial_match_buffer: String::new(), partial_match_buffer: String::new(),
stream_finish_reason: None, stream_finish_reason: None,
emitted_tool_calls_count: 0, emitted_tool_calls_count: 0,
...@@ -152,16 +157,41 @@ impl ChoiceJailState { ...@@ -152,16 +157,41 @@ impl ChoiceJailState {
} }
} }
/// Add content to this choice's accumulation /// Add content and logprobs to this choice's accumulation
fn accumulate(&mut self, content: &str) { fn accumulate(&mut self, content: &str, logprobs: Option<&ChatChoiceLogprobs>) {
if self.is_jailed { if self.is_jailed {
self.accumulated_content.push_str(content); self.accumulated_content.push_str(content);
// Accumulate logprobs so they are preserved across jailed chunks.
if let Some(lp) = logprobs {
let state_lps = self.accumulated_logprobs.get_or_insert(ChatChoiceLogprobs {
content: None,
refusal: None,
});
if let Some(content_lps) = &lp.content {
state_lps
.content
.get_or_insert_with(Vec::new)
.extend(content_lps.clone());
}
if let Some(refusal_lps) = &lp.refusal {
state_lps
.refusal
.get_or_insert_with(Vec::new)
.extend(refusal_lps.clone());
}
}
} }
} }
/// Consume the accumulated logprobs, replacing them with `None`.
fn take_accumulated_logprobs(&mut self) -> Option<ChatChoiceLogprobs> {
self.accumulated_logprobs.take()
}
/// End jailing and return the accumulated content /// End jailing and return the accumulated content
fn end_jail(&mut self) -> String { fn end_jail(&mut self) -> String {
self.is_jailed = false; self.is_jailed = false;
self.accumulated_logprobs = None;
std::mem::take(&mut self.accumulated_content) std::mem::take(&mut self.accumulated_content)
} }
...@@ -235,6 +265,8 @@ impl ChoiceJailState { ...@@ -235,6 +265,8 @@ impl ChoiceJailState {
if jail_stream.should_start_jail(trailing_part) { if jail_stream.should_start_jail(trailing_part) {
self.is_jailed = true; self.is_jailed = true;
self.accumulated_content = trailing_part.to_string(); self.accumulated_content = trailing_part.to_string();
// No logprobs to seed here — they were already emitted with the tool call
self.accumulated_logprobs = None;
} else { } else {
#[allow(deprecated)] #[allow(deprecated)]
let trailing_choice = create_choice_stream( let trailing_choice = create_choice_stream(
...@@ -253,6 +285,8 @@ impl ChoiceJailState { ...@@ -253,6 +285,8 @@ impl ChoiceJailState {
// Start jailing with the marker and suffix // Start jailing with the marker and suffix
self.is_jailed = true; self.is_jailed = true;
self.accumulated_content = full_content; self.accumulated_content = full_content;
// Seed accumulated logprobs with this chunk's logprobs
self.accumulated_logprobs = choice.logprobs.clone();
} }
self.partial_match_buffer.clear(); self.partial_match_buffer.clear();
...@@ -301,6 +335,8 @@ impl ChoiceJailState { ...@@ -301,6 +335,8 @@ impl ChoiceJailState {
// Start jailing with the combined content // Start jailing with the combined content
self.is_jailed = true; self.is_jailed = true;
self.accumulated_content = combined_content; self.accumulated_content = combined_content;
// Seed accumulated logprobs with this chunk's logprobs
self.accumulated_logprobs = choice.logprobs.clone();
self.partial_match_buffer.clear(); self.partial_match_buffer.clear();
} else { } else {
// No markers - emit everything // No markers - emit everything
...@@ -322,25 +358,31 @@ impl ChoiceJailState { ...@@ -322,25 +358,31 @@ impl ChoiceJailState {
} }
} }
} else { } else {
// Already jailed - accumulate and check for unjail // Already jailed - accumulate content AND logprobs, then check for unjail
self.accumulate(content); self.accumulate(content, choice.logprobs.as_ref());
let (should_end, split_pos) = let (should_end, split_pos) =
jail_stream.should_end_jail(&self.accumulated_content).await; jail_stream.should_end_jail(&self.accumulated_content).await;
if should_end { if should_end {
// Take accumulated logprobs before borrowing accumulated_content
let jail_logprobs = self.take_accumulated_logprobs();
// Split the content // Split the content
let (jailed_part, trailing_part) = self.accumulated_content.split_at(split_pos); let (jailed_part, trailing_part) = self.accumulated_content.split_at(split_pos);
let trailing_owned = trailing_part.to_string();
let jailed_owned = jailed_part.to_string();
// Create the unjailed choice // Create the unjailed choice, using accumulated logprobs
let unjailed_choice = jail_stream let mut unjailed_choice = jail_stream
.create_tool_call_choice( .create_tool_call_choice(
choice.index, choice.index,
jailed_part, &jailed_owned,
choice, choice,
self.emitted_tool_calls_count, self.emitted_tool_calls_count,
) )
.await; .await;
unjailed_choice.logprobs = jail_logprobs;
// Determine emission type based on whether tool calls were parsed // Determine emission type based on whether tool calls were parsed
if unjailed_choice.delta.tool_calls.is_some() { if unjailed_choice.delta.tool_calls.is_some() {
...@@ -353,7 +395,6 @@ impl ChoiceJailState { ...@@ -353,7 +395,6 @@ impl ChoiceJailState {
} }
// End jailing before processing trailing content // End jailing before processing trailing content
let trailing_owned = trailing_part.to_string();
self.end_jail(); self.end_jail();
// Handle trailing content if any // Handle trailing content if any
...@@ -393,7 +434,7 @@ impl ChoiceJailState { ...@@ -393,7 +434,7 @@ impl ChoiceJailState {
None, None,
self.stream_finish_reason, // For the accumulated content, assign the original stream finish reason, otherwise it will get lost self.stream_finish_reason, // For the accumulated content, assign the original stream finish reason, otherwise it will get lost
None, None,
None, self.accumulated_logprobs.clone(),
); );
let mut final_choice = jail_stream let mut final_choice = jail_stream
...@@ -404,6 +445,8 @@ impl ChoiceJailState { ...@@ -404,6 +445,8 @@ impl ChoiceJailState {
self.emitted_tool_calls_count, self.emitted_tool_calls_count,
) )
.await; .await;
// Attach the full accumulated logprobs to the final choice
final_choice.logprobs = self.take_accumulated_logprobs();
// Preserve any pending reasoning content collected while jailed. // Preserve any pending reasoning content collected while jailed.
if let Some(pending_reasoning) = self.pending_reasoning_content.take() { if let Some(pending_reasoning) = self.pending_reasoning_content.take() {
...@@ -928,7 +971,7 @@ impl JailedStream { ...@@ -928,7 +971,7 @@ impl JailedStream {
Some(tool_call_chunks), Some(tool_call_chunks),
None, None,
None, None,
None, base_choice.logprobs.clone(),
); );
return choice; return choice;
} }
...@@ -1413,6 +1456,164 @@ mod tests { ...@@ -1413,6 +1456,164 @@ mod tests {
.collect() .collect()
} }
/// Helper: build a single-choice stream chunk with text content and logprobs
fn text_chunk_with_logprobs(text: &str) -> Annotated<NvCreateChatCompletionStreamResponse> {
let logprobs = ChatChoiceLogprobs {
content: Some(
text.chars()
.enumerate()
.map(
|(i, c)| dynamo_protocols::types::ChatCompletionTokenLogprob {
token: c.to_string(),
logprob: -(i as f32 + 1.0) * 0.1,
bytes: Some(c.to_string().into_bytes()),
top_logprobs: vec![],
},
)
.collect(),
),
refusal: None,
};
let choice = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(dynamo_protocols::types::ChatCompletionMessageContent::Text(
text.to_string(),
)),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: Some(logprobs),
};
Annotated {
data: Some(NvCreateChatCompletionStreamResponse {
inner: CreateChatCompletionStreamResponse {
id: "id-42".to_string(),
object: "chat.completion.chunk".to_string(),
created: 0,
model: "test-model".to_string(),
choices: vec![choice],
usage: None,
service_tier: None,
system_fingerprint: None,
},
nvext: None,
}),
id: None,
event: None,
comment: None,
error: None,
}
}
/// Collect all logprobs from jailed stream output choices
fn collect_logprobs(
responses: &[Annotated<NvCreateChatCompletionStreamResponse>],
) -> Vec<Option<ChatChoiceLogprobs>> {
responses
.iter()
.flat_map(|r| r.data.iter())
.flat_map(|d| d.inner.choices.iter())
.map(|c| c.logprobs.clone())
.collect()
}
#[tokio::test]
async fn test_tool_call_preserves_logprobs_single_chunk() {
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let chunks = vec![text_chunk_with_logprobs(
"<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"SF\"}}\n</tool_call>",
)];
let input_stream = Box::pin(stream::iter(chunks));
let output_stream = jail.apply_with_finish_reason(input_stream);
let responses: Vec<_> = output_stream.collect().await;
let tool_calls = collect_tool_calls(&responses);
assert_eq!(
tool_calls.len(),
1,
"Expected 1 tool call, got {:?}",
tool_calls
);
assert_eq!(tool_calls[0].0, "get_weather");
// Logprobs must be preserved even though the entire output is a tool call
let all_logprobs = collect_logprobs(&responses);
let has_some_logprobs = all_logprobs.iter().any(|lp| lp.is_some());
assert!(
has_some_logprobs,
"Logprobs should be preserved for tool call responses, got all None: {:?}",
all_logprobs
);
}
#[tokio::test]
async fn test_tool_call_preserves_logprobs_multiple_chunks() {
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let chunks = vec![
text_chunk_with_logprobs("<tool_call>\n{\"name\": \"get_weather\", \"arguments\""),
text_chunk_with_logprobs(": {\"location\": \"SF\"}}\n</tool_call>"),
];
let input_stream = Box::pin(stream::iter(chunks));
let output_stream = jail.apply_with_finish_reason(input_stream);
let responses: Vec<_> = output_stream.collect().await;
let tool_calls = collect_tool_calls(&responses);
assert!(!tool_calls.is_empty(), "Expected tool calls, got none");
let all_logprobs = collect_logprobs(&responses);
let has_some_logprobs = all_logprobs.iter().any(|lp| lp.is_some());
assert!(
has_some_logprobs,
"Logprobs should be preserved for tool call responses across chunks, got all None",
);
}
#[tokio::test]
async fn test_tool_call_with_text_preserves_logprobs() {
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let chunks = vec![text_chunk_with_logprobs(
"Let me check.\n<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"SF\"}}\n</tool_call>",
)];
let input_stream = Box::pin(stream::iter(chunks));
let output_stream = jail.apply_with_finish_reason(input_stream);
let responses: Vec<_> = output_stream.collect().await;
let tool_calls = collect_tool_calls(&responses);
assert_eq!(tool_calls.len(), 1);
let all_logprobs = collect_logprobs(&responses);
let has_some_logprobs = all_logprobs.iter().any(|lp| lp.is_some());
assert!(
has_some_logprobs,
"Logprobs should be preserved for mixed text+tool_call responses",
);
// Verify the logprobs content is non-empty
let logprob_entries: Vec<_> = all_logprobs
.iter()
.filter_map(|lp| lp.as_ref())
.filter_map(|lp| lp.content.as_ref())
.collect();
assert!(
logprob_entries.iter().any(|entries| !entries.is_empty()),
"Logprobs content should have entries",
);
}
#[tokio::test] #[tokio::test]
async fn test_multi_tool_call_single_chunk() { async fn test_multi_tool_call_single_chunk() {
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment