Unverified commit 947939c7, authored by Biswa Panda, committed by GitHub
Browse files

fix: populate logprobs bytes and token fields in OpenAI-compatible responses (#6953)

parent 1d509252
...@@ -1107,7 +1107,7 @@ class BaseWorkerHandler(ABC): ...@@ -1107,7 +1107,7 @@ class BaseWorkerHandler(ABC):
@staticmethod @staticmethod
def _extract_logprobs( def _extract_logprobs(
output, num_output_tokens_so_far: int output, num_output_tokens_so_far: int, tokenizer=None
) -> tuple[list[float] | None, list[list[dict]] | None]: ) -> tuple[list[float] | None, list[list[dict]] | None]:
""" """
Extract logprobs from vLLM CompletionOutput for new tokens. Extract logprobs from vLLM CompletionOutput for new tokens.
...@@ -1115,6 +1115,8 @@ class BaseWorkerHandler(ABC): ...@@ -1115,6 +1115,8 @@ class BaseWorkerHandler(ABC):
Args: Args:
output: vLLM CompletionOutput object output: vLLM CompletionOutput object
num_output_tokens_so_far: Number of tokens already processed num_output_tokens_so_far: Number of tokens already processed
tokenizer: Optional tokenizer for decoding token IDs when
decoded_token is not populated by the engine
Returns: Returns:
Tuple of (log_probs, top_logprobs) in Dynamo's expected format: Tuple of (log_probs, top_logprobs) in Dynamo's expected format:
...@@ -1147,18 +1149,23 @@ class BaseWorkerHandler(ABC): ...@@ -1147,18 +1149,23 @@ class BaseWorkerHandler(ABC):
# Build top_logprobs list for this token position # Build top_logprobs list for this token position
token_top_logprobs = [] token_top_logprobs = []
for tok_id, logprob_info in token_logprobs_dict.items(): for tok_id, logprob_info in token_logprobs_dict.items():
token_str = getattr(logprob_info, "decoded_token", None)
if not token_str and tokenizer:
try:
token_str = tokenizer.decode([tok_id])
except Exception:
token_str = None
token_top_logprobs.append( token_top_logprobs.append(
{ {
"rank": ( "rank": (
logprob_info.rank if hasattr(logprob_info, "rank") else 0 logprob_info.rank if hasattr(logprob_info, "rank") else 0
), ),
"token_id": tok_id, "token_id": tok_id,
"token": ( "token": token_str,
logprob_info.decoded_token
if hasattr(logprob_info, "decoded_token")
else None
),
"logprob": float(logprob_info.logprob), "logprob": float(logprob_info.logprob),
"bytes": (
list(token_str.encode("utf-8")) if token_str else None
),
} }
) )
top_logprobs.append(token_top_logprobs) top_logprobs.append(token_top_logprobs)
...@@ -1250,8 +1257,9 @@ class BaseWorkerHandler(ABC): ...@@ -1250,8 +1257,9 @@ class BaseWorkerHandler(ABC):
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
# Extract logprobs for new tokens if available # Extract logprobs for new tokens if available
tokenizer = getattr(self.engine_client, "tokenizer", None)
log_probs, top_logprobs = self._extract_logprobs( log_probs, top_logprobs = self._extract_logprobs(
output, num_output_tokens_so_far output, num_output_tokens_so_far, tokenizer=tokenizer
) )
if log_probs is not None: if log_probs is not None:
out["log_probs"] = log_probs out["log_probs"] = log_probs
......
...@@ -59,6 +59,8 @@ pub struct TopLogprob { ...@@ -59,6 +59,8 @@ pub struct TopLogprob {
pub token_id: TokenIdType, pub token_id: TokenIdType,
pub token: TokenType, pub token: TokenType,
pub logprob: f64, pub logprob: f64,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bytes: Option<Vec<u8>>,
} }
pub type TopLogprobs = Vec<Vec<TopLogprob>>; // num_tokens x top_logprobs pub type TopLogprobs = Vec<Vec<TopLogprob>>; // num_tokens x top_logprobs
......
...@@ -9,6 +9,7 @@ use super::{ ...@@ -9,6 +9,7 @@ use super::{
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider}, common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
}; };
use crate::protocols::openai::common_ext::CommonExtProvider; use crate::protocols::openai::common_ext::CommonExtProvider;
use crate::types::TokenIdType;
pub mod chat_completions; pub mod chat_completions;
pub mod common_ext; pub mod common_ext;
...@@ -211,6 +212,49 @@ impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T { ...@@ -211,6 +212,49 @@ impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
} }
} }
/// Produces the raw UTF-8 bytes of `token`, as used for the `bytes` field of OpenAI logprobs
/// entries.
///
/// An empty `token` (how unknown/unresolved backend tokens are represented) yields `None`;
/// any other string yields `Some` holding an owned copy of its UTF-8 encoding.
pub(crate) fn token_to_utf8_bytes(token: &str) -> Option<Vec<u8>> {
    match token {
        "" => None,
        text => Some(text.bytes().collect()),
    }
}
/// Builds the OpenAI-compatible top-logprob list for one token position from the backend's
/// `TopLogprob` entries.
///
/// Every entry of `top_lps` is carried over in its original order: a missing token string
/// becomes the empty string, the logprob is narrowed to `f32`, and `bytes` takes the
/// backend-supplied value when present, otherwise the UTF-8 bytes of the token string.
///
/// If no entry in `top_lps` has `selected_token_id`, one extra entry describing the selected
/// token (`selected_token`, `selected_logprob`) is appended at the end, so the selected token
/// is always present in the result.
pub(crate) fn convert_backend_top_logprobs(
    top_lps: &[common::llm_backend::TopLogprob],
    selected_token: &str,
    selected_token_id: TokenIdType,
    selected_logprob: f32,
) -> Vec<dynamo_async_openai::types::TopLogprobs> {
    // One spare slot in case the selected token has to be appended below.
    let mut entries = Vec::with_capacity(top_lps.len() + 1);
    let mut selected_seen = false;

    for candidate in top_lps {
        if candidate.token_id == selected_token_id {
            selected_seen = true;
        }
        let text = candidate.token.clone().unwrap_or_default();
        // Prefer bytes reported by the backend; fall back to encoding the token text.
        let bytes = candidate
            .bytes
            .clone()
            .or_else(|| token_to_utf8_bytes(&text));
        entries.push(dynamo_async_openai::types::TopLogprobs {
            token: text,
            logprob: candidate.logprob as f32,
            bytes,
        });
    }

    if !selected_seen {
        entries.push(dynamo_async_openai::types::TopLogprobs {
            token: selected_token.to_string(),
            logprob: selected_logprob,
            bytes: token_to_utf8_bytes(selected_token),
        });
    }

    entries
}
pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>: pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
Send + 'static Send + 'static
{ {
......
...@@ -381,6 +381,7 @@ impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompleti ...@@ -381,6 +381,7 @@ impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompleti
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::openai::token_to_utf8_bytes;
use futures::stream; use futures::stream;
#[allow(deprecated)] #[allow(deprecated)]
...@@ -421,16 +422,19 @@ mod tests { ...@@ -421,16 +422,19 @@ mod tests {
refusal: None, refusal: None,
reasoning_content: None, reasoning_content: None,
}; };
let logprobs = logprob.map(|lp| dynamo_async_openai::types::ChatChoiceLogprobs { let logprobs = logprob.map(|lp| {
let token = text.to_string();
dynamo_async_openai::types::ChatChoiceLogprobs {
content: Some(vec![ content: Some(vec![
dynamo_async_openai::types::ChatCompletionTokenLogprob { dynamo_async_openai::types::ChatCompletionTokenLogprob {
token: text.to_string(), token: token.clone(),
logprob: lp, logprob: lp,
bytes: None, bytes: token_to_utf8_bytes(&token),
top_logprobs: vec![], top_logprobs: vec![],
}, },
]), ]),
refusal: None, refusal: None,
}
}); });
let choice = dynamo_async_openai::types::ChatChoiceStream { let choice = dynamo_async_openai::types::ChatChoiceStream {
index, index,
......
...@@ -8,7 +8,11 @@ use crate::{ ...@@ -8,7 +8,11 @@ use crate::{
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
protocols::{ protocols::{
common::{self, timing::RequestTracker}, common::{self, timing::RequestTracker},
openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo}, openai::{
convert_backend_top_logprobs,
nvext::{NvExtProvider, NvExtResponse, TimingInfo},
token_to_utf8_bytes,
},
}, },
types::TokenIdType, types::TokenIdType,
}; };
...@@ -211,33 +215,12 @@ impl DeltaGenerator { ...@@ -211,33 +215,12 @@ impl DeltaGenerator {
.zip(tok_lps) .zip(tok_lps)
.zip(top_logprobs) .zip(top_logprobs)
.map(|(((t, tid), lp), top_lps)| { .map(|(((t, tid), lp), top_lps)| {
let mut found_selected_token = false; let converted = convert_backend_top_logprobs(&top_lps, t, *tid, lp);
let mut converted_top_lps = top_lps
.iter()
.map(|top_lp| {
let top_t = top_lp.token.clone().unwrap_or_default();
let top_tid = top_lp.token_id;
found_selected_token = found_selected_token || top_tid == *tid;
dynamo_async_openai::types::TopLogprobs {
token: top_t,
logprob: top_lp.logprob as f32,
bytes: None,
}
})
.collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
if !found_selected_token {
// If the selected token is not in the top logprobs, add it
converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
token: t.clone(),
logprob: lp,
bytes: None,
});
}
dynamo_async_openai::types::ChatCompletionTokenLogprob { dynamo_async_openai::types::ChatCompletionTokenLogprob {
token: t.clone(), token: t.clone(),
logprob: lp, logprob: lp,
bytes: None, bytes: token_to_utf8_bytes(t),
top_logprobs: converted_top_lps, top_logprobs: converted,
} }
}) })
.collect() .collect()
......
...@@ -7,7 +7,10 @@ use super::{NvCreateCompletionRequest, NvCreateCompletionResponse}; ...@@ -7,7 +7,10 @@ use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{ use crate::{
protocols::{ protocols::{
common::{self, timing::RequestTracker}, common::{self, timing::RequestTracker},
openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo}, openai::{
convert_backend_top_logprobs,
nvext::{NvExtProvider, NvExtResponse, TimingInfo},
},
}, },
types::TokenIdType, types::TokenIdType,
}; };
...@@ -172,29 +175,8 @@ impl DeltaGenerator { ...@@ -172,29 +175,8 @@ impl DeltaGenerator {
.zip(tok_lps.iter()) .zip(tok_lps.iter())
.zip(top_logprobs.iter()) .zip(top_logprobs.iter())
.map(|(((t, tid), lp), top_lps)| { .map(|(((t, tid), lp), top_lps)| {
let mut found_selected_token = false; let converted = convert_backend_top_logprobs(top_lps, t, *tid, *lp);
let mut converted_top_lps = top_lps serde_json::to_value(converted).unwrap()
.iter()
.map(|top_lp| {
let top_t = top_lp.token.clone().unwrap_or_default();
let top_tid = top_lp.token_id;
found_selected_token = found_selected_token || top_tid == *tid;
dynamo_async_openai::types::TopLogprobs {
token: top_t,
logprob: top_lp.logprob as f32,
bytes: None,
}
})
.collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
if !found_selected_token {
// If the selected token is not in the top logprobs, add it
converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
token: t.clone(),
logprob: *lp,
bytes: None,
});
}
serde_json::to_value(converted_top_lps).unwrap()
}) })
.collect() .collect()
}); });
......
...@@ -200,6 +200,33 @@ class ChatPayloadWithLogprobs(ChatPayload): ...@@ -200,6 +200,33 @@ class ChatPayloadWithLogprobs(ChatPayload):
logprob_val <= 0 logprob_val <= 0
), f"logprob should be <= 0, got {logprob_val}" ), f"logprob should be <= 0, got {logprob_val}"
# Validate bytes field is populated for the selected token
assert "bytes" in item, "Missing 'bytes' in logprobs content item"
token_str = item["token"]
if token_str:
assert (
item["bytes"] is not None
), f"'bytes' should be populated for non-empty token {token_str!r}"
assert isinstance(
item["bytes"], list
), f"'bytes' should be a list, got {type(item['bytes'])}"
# Validate top_logprobs entries have token, logprob, and bytes
for top_lp in item["top_logprobs"]:
assert (
"token" in top_lp
), "Missing 'token' in top_logprobs entry"
assert (
"logprob" in top_lp
), "Missing 'logprob' in top_logprobs entry"
assert (
"bytes" in top_lp
), "Missing 'bytes' in top_logprobs entry"
if top_lp["token"]:
assert (
top_lp["bytes"] is not None
), f"'bytes' should be populated for top_logprob token {top_lp['token']!r}"
logger.info( logger.info(
f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs" f"✓ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs"
) )
...@@ -482,6 +509,26 @@ class CompletionPayloadWithLogprobs(CompletionPayload): ...@@ -482,6 +509,26 @@ class CompletionPayloadWithLogprobs(CompletionPayload):
logprob_val <= 0 logprob_val <= 0
), f"logprob at index {i} should be <= 0, got {logprob_val}" ), f"logprob at index {i} should be <= 0, got {logprob_val}"
# Validate top_logprobs entries have token, logprob, and bytes when present
top_logprobs_list = logprobs_data.get("top_logprobs", [])
for i, token_top_lps in enumerate(top_logprobs_list):
if not token_top_lps:
continue
for top_lp in token_top_lps:
assert (
"token" in top_lp
), f"Missing 'token' in top_logprobs[{i}] entry"
assert (
"logprob" in top_lp
), f"Missing 'logprob' in top_logprobs[{i}] entry"
assert (
"bytes" in top_lp
), f"Missing 'bytes' in top_logprobs[{i}] entry"
if top_lp["token"]:
assert (
top_lp["bytes"] is not None
), f"'bytes' should be populated for top_logprob token {top_lp['token']!r}"
logger.info( logger.info(
f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs" f"✓ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment