Unverified commit 45e881d3, authored by KrishnanPrash and committed by GitHub
Browse files

feat: Support for field include_stop_str_in_output (#4924)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
parent ef3027bd
...@@ -118,6 +118,7 @@ fn build_backend_output(text: &str) -> BackendOutput { ...@@ -118,6 +118,7 @@ fn build_backend_output(text: &str) -> BackendOutput {
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: Some(common::FinishReason::Stop), finish_reason: Some(common::FinishReason::Stop),
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: None, completion_usage: None,
disaggregated_params: None, disaggregated_params: None,
...@@ -285,6 +286,7 @@ async fn test_streaming_named_tool_buffers_until_finish() { ...@@ -285,6 +286,7 @@ async fn test_streaming_named_tool_buffers_until_finish() {
} else { } else {
None None
}, },
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: None, completion_usage: None,
disaggregated_params: None, disaggregated_params: None,
...@@ -351,6 +353,7 @@ async fn test_streaming_required_tool_parallel() { ...@@ -351,6 +353,7 @@ async fn test_streaming_required_tool_parallel() {
} else { } else {
None None
}, },
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: None, completion_usage: None,
disaggregated_params: None, disaggregated_params: None,
...@@ -419,6 +422,7 @@ fn test_no_tool_choice_outputs_normal_text() { ...@@ -419,6 +422,7 @@ fn test_no_tool_choice_outputs_normal_text() {
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: None, completion_usage: None,
disaggregated_params: None, disaggregated_params: None,
......
...@@ -45,6 +45,7 @@ fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> ...@@ -45,6 +45,7 @@ fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) ->
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: Some(finish), finish_reason: Some(finish),
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: None, completion_usage: None,
disaggregated_params: None, disaggregated_params: None,
......
...@@ -274,17 +274,22 @@ async fn handle_writer( ...@@ -274,17 +274,22 @@ async fn handle_writer(
alive_rx: tokio::sync::oneshot::Receiver<()>, alive_rx: tokio::sync::oneshot::Receiver<()>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> { ) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
// Only send sentinel for normal channel closure
let mut send_sentinel = true;
loop { loop {
let msg = tokio::select! { let msg = tokio::select! {
biased; biased;
_ = context.killed() => { _ = context.killed() => {
tracing::trace!("context kill signal received; shutting down"); tracing::trace!("context kill signal received; shutting down");
send_sentinel = false;
break; break;
} }
_ = context.stopped() => { _ = context.stopped() => {
tracing::trace!("context stop signal received; shutting down"); tracing::trace!("context stop signal received; shutting down");
send_sentinel = false;
break; break;
} }
...@@ -304,14 +309,17 @@ async fn handle_writer( ...@@ -304,14 +309,17 @@ async fn handle_writer(
"failed to send message to network; possible disconnect: {:?}", "failed to send message to network; possible disconnect: {:?}",
e e
); );
send_sentinel = false;
break; break;
} }
} }
// send sentinel message // Send sentinel only on normal closure
let message = serde_json::to_vec(&ControlMessage::Sentinel)?; if send_sentinel {
let msg = TwoPartMessage::from_header(message.into()); let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
framed_writer.send(msg).await?; let msg = TwoPartMessage::from_header(message.into());
framed_writer.send(msg).await?;
}
drop(alive_rx); drop(alive_rx);
Ok(framed_writer) Ok(framed_writer)
......
...@@ -60,6 +60,17 @@ vllm_configs = { ...@@ -60,6 +60,17 @@ vllm_configs = {
request_payloads=[ request_payloads=[
chat_payload_default(), chat_payload_default(),
completion_payload_default(), completion_payload_default(),
chat_payload(
"Can you write me a song?",
repeat_count=1,
expected_response=["song"],
temperature=0.0,
max_tokens=32,
extra_body={
"stop": ["song"],
"include_stop_str_in_output": True,
},
),
metric_payload_default(min_num_requests=6, backend="vllm"), metric_payload_default(min_num_requests=6, backend="vllm"),
], ],
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment