Unverified Commit 8e4ae22b authored by Greg Clark's avatar Greg Clark Committed by GitHub
Browse files

fix: aggregate logprobs (#2928)


Signed-off-by: default avatarGreg Clark <grclark@nvidia.com>
parent fae35432
......@@ -129,7 +129,7 @@ impl DeltaAggregator {
text: "".to_string(),
role: choice.delta.role,
finish_reason: None,
logprobs: choice.logprobs,
logprobs: None,
tool_calls: None,
reasoning_content: None,
});
......@@ -150,6 +150,28 @@ impl DeltaAggregator {
if let Some(finish_reason) = choice.finish_reason {
state_choice.finish_reason = Some(finish_reason);
}
// Update logprobs
if let Some(logprobs) = &choice.logprobs {
let state_lps = state_choice.logprobs.get_or_insert(
dynamo_async_openai::types::ChatChoiceLogprobs {
content: None,
refusal: None,
},
);
if let Some(content_lps) = &logprobs.content {
state_lps
.content
.get_or_insert(Vec::new())
.extend(content_lps.clone());
}
if let Some(refusal_lps) = &logprobs.refusal {
state_lps
.refusal
.get_or_insert(Vec::new())
.extend(refusal_lps.clone());
}
}
}
}
aggregator
......@@ -305,6 +327,7 @@ mod tests {
text: &str,
role: Option<dynamo_async_openai::types::Role>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprob: Option<f32>,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
// ALLOW: function_call is deprecated
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
......@@ -315,11 +338,22 @@ mod tests {
refusal: None,
reasoning_content: None,
};
let logprobs = logprob.map(|lp| dynamo_async_openai::types::ChatChoiceLogprobs {
content: Some(vec![
dynamo_async_openai::types::ChatCompletionTokenLogprob {
token: text.to_string(),
logprob: lp,
bytes: None,
top_logprobs: vec![],
},
]),
refusal: None,
});
let choice = dynamo_async_openai::types::ChatChoiceStream {
index,
delta,
finish_reason,
logprobs: None,
logprobs,
};
let data = NvCreateChatCompletionStreamResponse {
......@@ -372,6 +406,7 @@ mod tests {
"Hello,",
Some(dynamo_async_openai::types::Role::User),
None,
None,
);
// Create a stream
......@@ -409,12 +444,14 @@ mod tests {
"Hello,",
Some(dynamo_async_openai::types::Role::User),
None,
Some(-0.1),
);
let annotated_delta2 = create_test_delta(
0,
" world!",
None,
Some(dynamo_async_openai::types::FinishReason::Stop),
Some(-0.2),
);
// Create a stream
......@@ -438,6 +475,25 @@ mod tests {
Some(dynamo_async_openai::types::FinishReason::Stop)
);
assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
assert_eq!(
choice
.logprobs
.as_ref()
.unwrap()
.content
.as_ref()
.unwrap()
.len(),
2
);
assert_eq!(
choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[0].logprob,
-0.1
);
assert_eq!(
choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[1].logprob,
-0.2
);
}
#[allow(deprecated)]
......@@ -538,6 +594,7 @@ mod tests {
tool_call_json,
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls),
None,
);
let data = annotated_delta.data.unwrap();
......@@ -598,6 +655,7 @@ mod tests {
tool_call_json,
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls),
None,
);
let data = annotated_delta.data.unwrap();
......
......@@ -95,7 +95,7 @@ impl DeltaAggregator {
index: choice.index,
text: "".to_string(),
finish_reason: None,
logprobs: choice.logprobs,
logprobs: None,
});
state_choice.text.push_str(&choice.text);
......@@ -115,6 +115,24 @@ impl DeltaAggregator {
) => Some(FinishReason::ContentFilter),
None => None,
};
// Update logprobs
if let Some(logprobs) = &choice.logprobs {
let state_lps = state_choice.logprobs.get_or_insert(
dynamo_async_openai::types::Logprobs {
tokens: Vec::new(),
token_logprobs: Vec::new(),
top_logprobs: Vec::new(),
text_offset: Vec::new(),
},
);
state_lps.tokens.extend(logprobs.tokens.clone());
state_lps
.token_logprobs
.extend(logprobs.token_logprobs.clone());
state_lps.top_logprobs.extend(logprobs.top_logprobs.clone());
state_lps.text_offset.extend(logprobs.text_offset.clone());
}
}
}
aggregator
......@@ -196,6 +214,7 @@ mod tests {
index: u32,
text: &str,
finish_reason: Option<String>,
logprob: Option<f32>,
) -> Annotated<NvCreateCompletionResponse> {
// This will silently discard invalid_finish reason values and fall back
// to None - totally fine since this is test code
......@@ -204,6 +223,20 @@ mod tests {
.and_then(|s| FinishReason::from_str(s).ok())
.map(Into::into);
let logprobs = logprob.map(|lp| dynamo_async_openai::types::Logprobs {
tokens: vec![text.to_string()],
token_logprobs: vec![Some(lp)],
top_logprobs: vec![
serde_json::to_value(dynamo_async_openai::types::TopLogprobs {
token: text.to_string(),
logprob: lp,
bytes: None,
})
.unwrap(),
],
text_offset: vec![0],
});
let inner = dynamo_async_openai::types::CreateCompletionResponse {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(),
......@@ -214,7 +247,7 @@ mod tests {
index,
text: text.to_string(),
finish_reason,
logprobs: None,
logprobs,
}],
object: "text_completion".to_string(),
};
......@@ -253,7 +286,7 @@ mod tests {
#[tokio::test]
async fn test_single_delta() {
// Create a sample delta
let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()));
let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()), None);
// Create a stream
let stream = Box::pin(stream::iter(vec![annotated_delta]));
......@@ -291,8 +324,9 @@ mod tests {
// Create multiple deltas with the same choice index
// One will have a MessageRole and no FinishReason,
// the other will have a FinishReason and no MessageRole
let annotated_delta1 = create_test_delta(0, "Hello,", None);
let annotated_delta2 = create_test_delta(0, " world!", Some("stop".to_string()));
let annotated_delta1 = create_test_delta(0, "Hello,", None, Some(-0.1));
let annotated_delta2 =
create_test_delta(0, " world!", Some("stop".to_string()), Some(-0.2));
// Create a stream
let annotated_deltas = vec![annotated_delta1, annotated_delta2];
......@@ -314,9 +348,10 @@ mod tests {
choice.finish_reason,
Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
);
assert_eq!(choice.logprobs.as_ref().unwrap().tokens.len(), 2);
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
choice.logprobs.as_ref().unwrap().token_logprobs,
vec![Some(-0.1), Some(-0.2)]
);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment