Unverified commit 8e4ae22b, authored by Greg Clark and committed by GitHub
Browse files

fix: aggregate logprobs (#2928)


Signed-off-by: Greg Clark <grclark@nvidia.com>
parent fae35432
...@@ -129,7 +129,7 @@ impl DeltaAggregator { ...@@ -129,7 +129,7 @@ impl DeltaAggregator {
text: "".to_string(), text: "".to_string(),
role: choice.delta.role, role: choice.delta.role,
finish_reason: None, finish_reason: None,
logprobs: choice.logprobs, logprobs: None,
tool_calls: None, tool_calls: None,
reasoning_content: None, reasoning_content: None,
}); });
...@@ -150,6 +150,28 @@ impl DeltaAggregator { ...@@ -150,6 +150,28 @@ impl DeltaAggregator {
if let Some(finish_reason) = choice.finish_reason { if let Some(finish_reason) = choice.finish_reason {
state_choice.finish_reason = Some(finish_reason); state_choice.finish_reason = Some(finish_reason);
} }
// Update logprobs
if let Some(logprobs) = &choice.logprobs {
let state_lps = state_choice.logprobs.get_or_insert(
dynamo_async_openai::types::ChatChoiceLogprobs {
content: None,
refusal: None,
},
);
if let Some(content_lps) = &logprobs.content {
state_lps
.content
.get_or_insert(Vec::new())
.extend(content_lps.clone());
}
if let Some(refusal_lps) = &logprobs.refusal {
state_lps
.refusal
.get_or_insert(Vec::new())
.extend(refusal_lps.clone());
}
}
} }
} }
aggregator aggregator
...@@ -305,6 +327,7 @@ mod tests { ...@@ -305,6 +327,7 @@ mod tests {
text: &str, text: &str,
role: Option<dynamo_async_openai::types::Role>, role: Option<dynamo_async_openai::types::Role>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>, finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprob: Option<f32>,
) -> Annotated<NvCreateChatCompletionStreamResponse> { ) -> Annotated<NvCreateChatCompletionStreamResponse> {
// ALLOW: function_call is deprecated // ALLOW: function_call is deprecated
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
...@@ -315,11 +338,22 @@ mod tests { ...@@ -315,11 +338,22 @@ mod tests {
refusal: None, refusal: None,
reasoning_content: None, reasoning_content: None,
}; };
let logprobs = logprob.map(|lp| dynamo_async_openai::types::ChatChoiceLogprobs {
content: Some(vec![
dynamo_async_openai::types::ChatCompletionTokenLogprob {
token: text.to_string(),
logprob: lp,
bytes: None,
top_logprobs: vec![],
},
]),
refusal: None,
});
let choice = dynamo_async_openai::types::ChatChoiceStream { let choice = dynamo_async_openai::types::ChatChoiceStream {
index, index,
delta, delta,
finish_reason, finish_reason,
logprobs: None, logprobs,
}; };
let data = NvCreateChatCompletionStreamResponse { let data = NvCreateChatCompletionStreamResponse {
...@@ -372,6 +406,7 @@ mod tests { ...@@ -372,6 +406,7 @@ mod tests {
"Hello,", "Hello,",
Some(dynamo_async_openai::types::Role::User), Some(dynamo_async_openai::types::Role::User),
None, None,
None,
); );
// Create a stream // Create a stream
...@@ -409,12 +444,14 @@ mod tests { ...@@ -409,12 +444,14 @@ mod tests {
"Hello,", "Hello,",
Some(dynamo_async_openai::types::Role::User), Some(dynamo_async_openai::types::Role::User),
None, None,
Some(-0.1),
); );
let annotated_delta2 = create_test_delta( let annotated_delta2 = create_test_delta(
0, 0,
" world!", " world!",
None, None,
Some(dynamo_async_openai::types::FinishReason::Stop), Some(dynamo_async_openai::types::FinishReason::Stop),
Some(-0.2),
); );
// Create a stream // Create a stream
...@@ -438,6 +475,25 @@ mod tests { ...@@ -438,6 +475,25 @@ mod tests {
Some(dynamo_async_openai::types::FinishReason::Stop) Some(dynamo_async_openai::types::FinishReason::Stop)
); );
assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User); assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
assert_eq!(
choice
.logprobs
.as_ref()
.unwrap()
.content
.as_ref()
.unwrap()
.len(),
2
);
assert_eq!(
choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[0].logprob,
-0.1
);
assert_eq!(
choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[1].logprob,
-0.2
);
} }
#[allow(deprecated)] #[allow(deprecated)]
...@@ -538,6 +594,7 @@ mod tests { ...@@ -538,6 +594,7 @@ mod tests {
tool_call_json, tool_call_json,
Some(dynamo_async_openai::types::Role::Assistant), Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls), Some(dynamo_async_openai::types::FinishReason::ToolCalls),
None,
); );
let data = annotated_delta.data.unwrap(); let data = annotated_delta.data.unwrap();
...@@ -598,6 +655,7 @@ mod tests { ...@@ -598,6 +655,7 @@ mod tests {
tool_call_json, tool_call_json,
Some(dynamo_async_openai::types::Role::Assistant), Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls), Some(dynamo_async_openai::types::FinishReason::ToolCalls),
None,
); );
let data = annotated_delta.data.unwrap(); let data = annotated_delta.data.unwrap();
......
...@@ -95,7 +95,7 @@ impl DeltaAggregator { ...@@ -95,7 +95,7 @@ impl DeltaAggregator {
index: choice.index, index: choice.index,
text: "".to_string(), text: "".to_string(),
finish_reason: None, finish_reason: None,
logprobs: choice.logprobs, logprobs: None,
}); });
state_choice.text.push_str(&choice.text); state_choice.text.push_str(&choice.text);
...@@ -115,6 +115,24 @@ impl DeltaAggregator { ...@@ -115,6 +115,24 @@ impl DeltaAggregator {
) => Some(FinishReason::ContentFilter), ) => Some(FinishReason::ContentFilter),
None => None, None => None,
}; };
// Update logprobs
if let Some(logprobs) = &choice.logprobs {
let state_lps = state_choice.logprobs.get_or_insert(
dynamo_async_openai::types::Logprobs {
tokens: Vec::new(),
token_logprobs: Vec::new(),
top_logprobs: Vec::new(),
text_offset: Vec::new(),
},
);
state_lps.tokens.extend(logprobs.tokens.clone());
state_lps
.token_logprobs
.extend(logprobs.token_logprobs.clone());
state_lps.top_logprobs.extend(logprobs.top_logprobs.clone());
state_lps.text_offset.extend(logprobs.text_offset.clone());
}
} }
} }
aggregator aggregator
...@@ -196,6 +214,7 @@ mod tests { ...@@ -196,6 +214,7 @@ mod tests {
index: u32, index: u32,
text: &str, text: &str,
finish_reason: Option<String>, finish_reason: Option<String>,
logprob: Option<f32>,
) -> Annotated<NvCreateCompletionResponse> { ) -> Annotated<NvCreateCompletionResponse> {
// This will silently discard invalid_finish reason values and fall back // This will silently discard invalid_finish reason values and fall back
// to None - totally fine since this is test code // to None - totally fine since this is test code
...@@ -204,6 +223,20 @@ mod tests { ...@@ -204,6 +223,20 @@ mod tests {
.and_then(|s| FinishReason::from_str(s).ok()) .and_then(|s| FinishReason::from_str(s).ok())
.map(Into::into); .map(Into::into);
let logprobs = logprob.map(|lp| dynamo_async_openai::types::Logprobs {
tokens: vec![text.to_string()],
token_logprobs: vec![Some(lp)],
top_logprobs: vec![
serde_json::to_value(dynamo_async_openai::types::TopLogprobs {
token: text.to_string(),
logprob: lp,
bytes: None,
})
.unwrap(),
],
text_offset: vec![0],
});
let inner = dynamo_async_openai::types::CreateCompletionResponse { let inner = dynamo_async_openai::types::CreateCompletionResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(), model: "meta/llama-3.1-8b".to_string(),
...@@ -214,7 +247,7 @@ mod tests { ...@@ -214,7 +247,7 @@ mod tests {
index, index,
text: text.to_string(), text: text.to_string(),
finish_reason, finish_reason,
logprobs: None, logprobs,
}], }],
object: "text_completion".to_string(), object: "text_completion".to_string(),
}; };
...@@ -253,7 +286,7 @@ mod tests { ...@@ -253,7 +286,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_single_delta() { async fn test_single_delta() {
// Create a sample delta // Create a sample delta
let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string())); let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()), None);
// Create a stream // Create a stream
let stream = Box::pin(stream::iter(vec![annotated_delta])); let stream = Box::pin(stream::iter(vec![annotated_delta]));
...@@ -291,8 +324,9 @@ mod tests { ...@@ -291,8 +324,9 @@ mod tests {
// Create multiple deltas with the same choice index // Create multiple deltas with the same choice index
// One will have a MessageRole and no FinishReason, // One will have a MessageRole and no FinishReason,
// the other will have a FinishReason and no MessageRole // the other will have a FinishReason and no MessageRole
let annotated_delta1 = create_test_delta(0, "Hello,", None); let annotated_delta1 = create_test_delta(0, "Hello,", None, Some(-0.1));
let annotated_delta2 = create_test_delta(0, " world!", Some("stop".to_string())); let annotated_delta2 =
create_test_delta(0, " world!", Some("stop".to_string()), Some(-0.2));
// Create a stream // Create a stream
let annotated_deltas = vec![annotated_delta1, annotated_delta2]; let annotated_deltas = vec![annotated_delta1, annotated_delta2];
...@@ -314,9 +348,10 @@ mod tests { ...@@ -314,9 +348,10 @@ mod tests {
choice.finish_reason, choice.finish_reason,
Some(dynamo_async_openai::types::CompletionFinishReason::Stop) Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
); );
assert_eq!(choice.logprobs.as_ref().unwrap().tokens.len(), 2);
assert_eq!( assert_eq!(
choice.finish_reason, choice.logprobs.as_ref().unwrap().token_logprobs,
Some(dynamo_async_openai::types::CompletionFinishReason::Stop) vec![Some(-0.1), Some(-0.2)]
); );
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment