"examples/vllm_v0/vscode:/vscode.git/clone" did not exist on "5505507b12b1ae8008f39e54d7bce967246d6449"
Unverified commit bce74588, authored by Graham King, committed by GitHub
Browse files

chore: Rust to 1.89 and edition 2024 (#2659)

parent 268d017e
...@@ -102,18 +102,18 @@ impl KvManager { ...@@ -102,18 +102,18 @@ impl KvManager {
store: bool, store: bool,
parent_hash: Option<u64>, parent_hash: Option<u64>,
) { ) {
if let Some(ref tx) = self.move_block_response_tx { if let Some(ref tx) = self.move_block_response_tx
if !blocks.is_empty() { && !blocks.is_empty()
if reverse { {
blocks.reverse(); if reverse {
} blocks.reverse();
let response = if store {
MoveBlockResponse::Store(blocks, parent_hash)
} else {
MoveBlockResponse::Remove(blocks)
};
tx.send(response).unwrap();
} }
let response = if store {
MoveBlockResponse::Store(blocks, parent_hash)
} else {
MoveBlockResponse::Remove(blocks)
};
tx.send(response).unwrap();
} }
} }
...@@ -159,10 +159,10 @@ impl KvManager { ...@@ -159,10 +159,10 @@ impl KvManager {
// Now insert the new block in active blocks with reference count 1 // Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1); self.active_blocks.insert(hash.clone(), 1);
self.all_blocks.insert(hash.clone()); self.all_blocks.insert(hash.clone());
if self.move_block_response_tx.is_some() { if self.move_block_response_tx.is_some()
if let UniqueBlock::FullBlock(stored_full_block) = hash { && let UniqueBlock::FullBlock(stored_full_block) = hash
blocks_stored.push(*stored_full_block); {
} blocks_stored.push(*stored_full_block);
} }
} }
...@@ -184,10 +184,10 @@ impl KvManager { ...@@ -184,10 +184,10 @@ impl KvManager {
assert!(self.all_blocks.remove(hash)); assert!(self.all_blocks.remove(hash));
// Track blocks for batch sending // Track blocks for batch sending
if self.move_block_response_tx.is_some() { if self.move_block_response_tx.is_some()
if let UniqueBlock::FullBlock(destroyed_full_block) = hash { && let UniqueBlock::FullBlock(destroyed_full_block) = hash
blocks_destroyed.push(*destroyed_full_block); {
} blocks_destroyed.push(*destroyed_full_block);
} }
} }
......
...@@ -160,58 +160,58 @@ impl MockEngineArgs { ...@@ -160,58 +160,58 @@ impl MockEngineArgs {
} }
// Apply each extra argument to the builder // Apply each extra argument to the builder
if let Some(value) = extra_args.get("num_gpu_blocks") { if let Some(value) = extra_args.get("num_gpu_blocks")
if let Some(num) = value.as_u64() { && let Some(num) = value.as_u64()
builder = builder.num_gpu_blocks(num as usize); {
} builder = builder.num_gpu_blocks(num as usize);
} }
if let Some(value) = extra_args.get("block_size") { if let Some(value) = extra_args.get("block_size")
if let Some(num) = value.as_u64() { && let Some(num) = value.as_u64()
builder = builder.block_size(num as usize); {
} builder = builder.block_size(num as usize);
} }
if let Some(value) = extra_args.get("max_num_seqs") { if let Some(value) = extra_args.get("max_num_seqs")
if let Some(num) = value.as_u64() { && let Some(num) = value.as_u64()
builder = builder.max_num_seqs(Some(num as usize)); {
} builder = builder.max_num_seqs(Some(num as usize));
} }
if let Some(value) = extra_args.get("max_num_batched_tokens") { if let Some(value) = extra_args.get("max_num_batched_tokens")
if let Some(num) = value.as_u64() { && let Some(num) = value.as_u64()
builder = builder.max_num_batched_tokens(Some(num as usize)); {
} builder = builder.max_num_batched_tokens(Some(num as usize));
} }
if let Some(value) = extra_args.get("enable_prefix_caching") { if let Some(value) = extra_args.get("enable_prefix_caching")
if let Some(enabled) = value.as_bool() { && let Some(enabled) = value.as_bool()
builder = builder.enable_prefix_caching(enabled); {
} builder = builder.enable_prefix_caching(enabled);
} }
if let Some(value) = extra_args.get("enable_chunked_prefill") { if let Some(value) = extra_args.get("enable_chunked_prefill")
if let Some(enabled) = value.as_bool() { && let Some(enabled) = value.as_bool()
builder = builder.enable_chunked_prefill(enabled); {
} builder = builder.enable_chunked_prefill(enabled);
} }
if let Some(value) = extra_args.get("watermark") { if let Some(value) = extra_args.get("watermark")
if let Some(num) = value.as_f64() { && let Some(num) = value.as_f64()
builder = builder.watermark(num); {
} builder = builder.watermark(num);
} }
if let Some(value) = extra_args.get("speedup_ratio") { if let Some(value) = extra_args.get("speedup_ratio")
if let Some(num) = value.as_f64() { && let Some(num) = value.as_f64()
builder = builder.speedup_ratio(num); {
} builder = builder.speedup_ratio(num);
} }
if let Some(value) = extra_args.get("dp_size") { if let Some(value) = extra_args.get("dp_size")
if let Some(num) = value.as_u64() { && let Some(num) = value.as_u64()
builder = builder.dp_size(num as u32); {
} builder = builder.dp_size(num as u32);
} }
// Build the MockEngineArgs with either defaults or overridden values // Build the MockEngineArgs with either defaults or overridden values
......
...@@ -31,15 +31,15 @@ ...@@ -31,15 +31,15 @@
use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats}; use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats};
use crate::mocker::evictor::LRUEvictor; use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager; use crate::mocker::kv_manager::KvManager;
use crate::mocker::protocols::{block_response_to_kv_event, MoveBlock, OutputSignal, PrefillCost};
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse}; use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
use crate::mocker::protocols::{MoveBlock, OutputSignal, PrefillCost, block_response_to_kv_event};
use crate::mocker::sequence::ActiveSequence; use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::BlockHash; use crate::tokens::BlockHash;
use crate::tokens::blocks::UniqueBlock;
use std::collections::HashMap; use std::collections::HashMap;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{Mutex, mpsc};
use tokio::time::Duration; use tokio::time::Duration;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
...@@ -414,9 +414,7 @@ impl Scheduler { ...@@ -414,9 +414,7 @@ impl Scheduler {
} }
// Drain KV events and forward to relay after prefill signal processing // Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
(&kv_events_tx, &mut block_resp_rx)
{
while let Ok(event) = rx.try_recv() { while let Ok(event) = rx.try_recv() {
let _ = let _ =
relay_tx.send(block_response_to_kv_event(event, &block_hashes)); relay_tx.send(block_response_to_kv_event(event, &block_hashes));
...@@ -465,9 +463,7 @@ impl Scheduler { ...@@ -465,9 +463,7 @@ impl Scheduler {
} }
// Drain KV events and forward to relay after decode signal processing // Drain KV events and forward to relay after decode signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
(&kv_events_tx, &mut block_resp_rx)
{
while let Ok(event) = rx.try_recv() { while let Ok(event) = rx.try_recv() {
let _ = relay_tx let _ = relay_tx
.send(block_response_to_kv_event(event, &sequence.block_hashes())); .send(block_response_to_kv_event(event, &sequence.block_hashes()));
...@@ -663,7 +659,9 @@ fn process_signals( ...@@ -663,7 +659,9 @@ fn process_signals(
// Check we have a Use signal with blocks // Check we have a Use signal with blocks
let MoveBlock::Use(blocks) = signal else { let MoveBlock::Use(blocks) = signal else {
panic!("Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"); panic!(
"Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
);
}; };
// Verify the signal contains exactly one block // Verify the signal contains exactly one block
...@@ -708,7 +706,7 @@ mod tests { ...@@ -708,7 +706,7 @@ mod tests {
#[case] enable_prefix_caching: bool, #[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool, #[case] enable_chunked_prefill: bool,
) { ) {
std::env::set_var("RUST_LOG", "debug"); unsafe { std::env::set_var("RUST_LOG", "debug") };
let kv_capacity: usize = 500; let kv_capacity: usize = 500;
let block_size: usize = 64; let block_size: usize = 64;
......
...@@ -568,7 +568,7 @@ mod tests { ...@@ -568,7 +568,7 @@ mod tests {
type TestTokenAlternative = (&'static str, f32); type TestTokenAlternative = (&'static str, f32);
type TestTokenData = (&'static str, f32, Vec<TestTokenAlternative>); type TestTokenData = (&'static str, f32, Vec<TestTokenAlternative>);
type TestTokenDataVec = Vec<TestTokenData>; type TestTokenDataVec = Vec<TestTokenData>;
use crate::perf::{record_stream_with_context, RecordingMode, TimestampedResponse}; use crate::perf::{RecordingMode, TimestampedResponse, record_stream_with_context};
use crate::protocols::codec::create_message_stream; use crate::protocols::codec::create_message_stream;
use crate::protocols::convert_sse_stream; use crate::protocols::convert_sse_stream;
use approx::assert_abs_diff_eq; use approx::assert_abs_diff_eq;
......
...@@ -28,21 +28,21 @@ use crate::tokenizers::Encoding; ...@@ -28,21 +28,21 @@ use crate::tokenizers::Encoding;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream}; use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{ use dynamo_runtime::pipeline::{
async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn, AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
}; };
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider}; use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};
use crate::protocols::{ use crate::protocols::{
common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider}, common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
openai::{ openai::{
DeltaGeneratorExt,
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
nvext::NvExtProvider, nvext::NvExtProvider,
DeltaGeneratorExt,
}, },
}; };
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer}; use crate::tokenizers::{HuggingFaceTokenizer, traits::Tokenizer};
use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput}; use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
...@@ -487,11 +487,7 @@ impl ...@@ -487,11 +487,7 @@ impl
&self, &self,
request: SingleIn<NvCreateChatCompletionRequest>, request: SingleIn<NvCreateChatCompletionRequest>,
next: Arc< next: Arc<
dyn AsyncEngine< dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
Error,
>,
>, >,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
// unpack the request // unpack the request
...@@ -545,11 +541,7 @@ impl ...@@ -545,11 +541,7 @@ impl
&self, &self,
request: SingleIn<NvCreateCompletionRequest>, request: SingleIn<NvCreateCompletionRequest>,
next: Arc< next: Arc<
dyn AsyncEngine< dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
Error,
>,
>, >,
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
// unpack the request // unpack the request
...@@ -603,10 +595,10 @@ impl ...@@ -603,10 +595,10 @@ impl
request: SingleIn<NvCreateEmbeddingRequest>, request: SingleIn<NvCreateEmbeddingRequest>,
next: Arc< next: Arc<
dyn AsyncEngine< dyn AsyncEngine<
SingleIn<PreprocessedEmbeddingRequest>, SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>, ManyOut<Annotated<EmbeddingsEngineOutput>>,
Error, Error,
>, >,
>, >,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
// Unpack request // Unpack request
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
use std::sync::Arc; use std::sync::Arc;
use super::tokcfg::{raise_exception, strftime_now, tojson, ChatTemplate}; use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson};
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment}; use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
use either::Either; use either::Either;
use minijinja::Environment; use minijinja::Environment;
...@@ -60,7 +60,9 @@ impl HfTokenizerConfigJsonFormatter { ...@@ -60,7 +60,9 @@ impl HfTokenizerConfigJsonFormatter {
match &chat_template.0 { match &chat_template.0 {
Either::Left(x) => { Either::Left(x) => {
if x.contains("add_generation_prompt") { if x.contains("add_generation_prompt") {
tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."); tracing::debug!(
"Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
);
supports_add_generation_prompt = Some(true); supports_add_generation_prompt = Some(true);
} }
env.add_template_owned("default", x.to_string())?; env.add_template_owned("default", x.to_string())?;
...@@ -72,11 +74,15 @@ impl HfTokenizerConfigJsonFormatter { ...@@ -72,11 +74,15 @@ impl HfTokenizerConfigJsonFormatter {
if v.contains("add_generation_prompt") { if v.contains("add_generation_prompt") {
match supports_add_generation_prompt { match supports_add_generation_prompt {
Some(true) | None => { Some(true) | None => {
tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."); tracing::debug!(
"Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
);
supports_add_generation_prompt = Some(true); supports_add_generation_prompt = Some(true);
} }
Some(false) => { Some(false) => {
tracing::warn!("Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt."); tracing::warn!(
"Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt."
);
} }
} }
} else { } else {
...@@ -86,7 +92,9 @@ impl HfTokenizerConfigJsonFormatter { ...@@ -86,7 +92,9 @@ impl HfTokenizerConfigJsonFormatter {
} }
} }
if env.templates().count() == 0 { if env.templates().count() == 0 {
anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools."); anyhow::bail!(
"Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools."
);
} }
} }
} }
......
...@@ -21,7 +21,7 @@ use chrono::{DateTime, Local}; ...@@ -21,7 +21,7 @@ use chrono::{DateTime, Local};
use either::Either; use either::Either;
use ggus::{GGufMetaKV, GGufReader}; use ggus::{GGufMetaKV, GGufReader};
use memmap2::Mmap; use memmap2::Mmap;
use minijinja::{value::Kwargs, Error, ErrorKind, Value}; use minijinja::{Error, ErrorKind, Value, value::Kwargs};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[allow(dead_code)] #[allow(dead_code)]
......
...@@ -98,14 +98,14 @@ where ...@@ -98,14 +98,14 @@ where
fn try_from(value: Message) -> Result<Annotated<T>, Self::Error> { fn try_from(value: Message) -> Result<Annotated<T>, Self::Error> {
// determine if the message had an error // determine if the message had an error
if let Some(event) = value.event.as_ref() { if let Some(event) = value.event.as_ref()
if event == "error" { && event == "error"
let message = match &value.comments { {
Some(comments) => comments.join("\n"), let message = match &value.comments {
None => "`event: error` detected, but no error message found".to_string(), Some(comments) => comments.join("\n"),
}; None => "`event: error` detected, but no error message found".to_string(),
return Err(message); };
} return Err(message);
} }
// try to deserialize the data to T // try to deserialize the data to T
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason; pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::maybe_error::MaybeError; use dynamo_runtime::protocols::maybe_error::MaybeError;
......
...@@ -5,8 +5,8 @@ use anyhow::Result; ...@@ -5,8 +5,8 @@ use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::{ use super::{
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider, ContentProvider,
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
}; };
use crate::protocols::openai::common_ext::CommonExtProvider; use crate::protocols::openai::common_ext::CommonExtProvider;
...@@ -20,7 +20,7 @@ pub mod responses; ...@@ -20,7 +20,7 @@ pub mod responses;
pub mod validate; pub mod validate;
use validate::{ use validate::{
validate_range, FREQUENCY_PENALTY_RANGE, PRESENCE_PENALTY_RANGE, TEMPERATURE_RANGE, TOP_P_RANGE, FREQUENCY_PENALTY_RANGE, PRESENCE_PENALTY_RANGE, TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
}; };
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
...@@ -147,10 +147,10 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T { ...@@ -147,10 +147,10 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
let min_tokens = self.get_min_tokens(); let min_tokens = self.get_min_tokens();
let stop = self.get_stop(); let stop = self.get_stop();
if let Some(stop) = &stop { if let Some(stop) = &stop
if stop.len() > 4 { && stop.len() > 4
anyhow::bail!("stop conditions must be less than 4") {
} anyhow::bail!("stop conditions must be less than 4")
} }
// Use the trait method to get ignore_eos, which handles precedence // Use the trait method to get ignore_eos, which handles precedence
......
...@@ -20,11 +20,11 @@ use validator::Validate; ...@@ -20,11 +20,11 @@ use validator::Validate;
use crate::engines::ValidateRequest; use crate::engines::ValidateRequest;
use super::{ use super::{
OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
common_ext::{CommonExt, CommonExtProvider}, common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt, nvext::NvExt,
nvext::NvExtProvider, nvext::NvExtProvider,
validate, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, validate,
OpenAIStopConditionsProvider,
}; };
pub mod aggregator; pub mod aggregator;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{ use crate::protocols::{
Annotated,
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, convert_sse_stream,
openai::ParsingOptions, openai::ParsingOptions,
Annotated,
}; };
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate; use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
...@@ -177,27 +165,26 @@ impl DeltaAggregator { ...@@ -177,27 +165,26 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax // After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() { for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() { if choice.tool_calls.is_none()
if let Ok(tool_calls) = try_tool_call_parse_aggregate( && let Ok(tool_calls) = try_tool_call_parse_aggregate(
&choice.text, &choice.text,
parsing_options.tool_call_parser.as_deref(), parsing_options.tool_call_parser.as_deref(),
) { )
if tool_calls.is_empty() { {
continue; if tool_calls.is_empty() {
} continue;
for tool_call in &tool_calls { }
tracing::debug!( for tool_call in &tool_calls {
tool_call_id = %tool_call.id, tracing::debug!(
function_name = %tool_call.function.name, tool_call_id = %tool_call.id,
arguments = %tool_call.function.arguments, function_name = %tool_call.function.name,
"Parsed structured tool call from aggregated content" arguments = %tool_call.function.arguments,
); "Parsed structured tool call from aggregated content"
} );
choice.tool_calls = Some(tool_calls);
choice.text.clear();
choice.finish_reason =
Some(dynamo_async_openai::types::FinishReason::ToolCalls);
} }
choice.tool_calls = Some(tool_calls);
choice.text.clear();
choice.finish_reason = Some(dynamo_async_openai::types::FinishReason::ToolCalls);
} }
} }
......
...@@ -21,11 +21,12 @@ use validator::Validate; ...@@ -21,11 +21,12 @@ use validator::Validate;
use crate::engines::ValidateRequest; use crate::engines::ValidateRequest;
use super::{ use super::{
ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider}, common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
common_ext::{CommonExt, CommonExtProvider}, common_ext::{CommonExt, CommonExtProvider},
nvext::{NvExt, NvExtProvider}, nvext::{NvExt, NvExtProvider},
validate, ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, validate,
OpenAIStopConditionsProvider,
}; };
mod aggregator; mod aggregator;
...@@ -87,12 +88,11 @@ impl NvExtProvider for NvCreateCompletionRequest { ...@@ -87,12 +88,11 @@ impl NvExtProvider for NvCreateCompletionRequest {
} }
fn raw_prompt(&self) -> Option<String> { fn raw_prompt(&self) -> Option<String> {
if let Some(nvext) = self.nvext.as_ref() { if let Some(nvext) = self.nvext.as_ref()
if let Some(use_raw_prompt) = nvext.use_raw_prompt { && let Some(use_raw_prompt) = nvext.use_raw_prompt
if use_raw_prompt { && use_raw_prompt
return Some(prompt_to_string(&self.inner.prompt)); {
} return Some(prompt_to_string(&self.inner.prompt));
}
} }
None None
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap; use std::collections::HashMap;
...@@ -20,11 +8,11 @@ use futures::{Stream, StreamExt}; ...@@ -20,11 +8,11 @@ use futures::{Stream, StreamExt};
use super::NvCreateCompletionResponse; use super::NvCreateCompletionResponse;
use crate::protocols::{ use crate::protocols::{
Annotated, DataStream,
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
common::FinishReason, common::FinishReason,
convert_sse_stream, convert_sse_stream,
openai::ParsingOptions, openai::ParsingOptions,
Annotated, DataStream,
}; };
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`]. /// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
use super::NvCreateEmbeddingResponse; use super::NvCreateEmbeddingResponse;
use crate::protocols::{ use crate::protocols::{
Annotated,
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream,
}; };
use dynamo_runtime::engine::DataStream; use dynamo_runtime::engine::DataStream;
...@@ -71,24 +72,23 @@ impl DeltaAggregator { ...@@ -71,24 +72,23 @@ impl DeltaAggregator {
} }
}; };
if aggregator.error.is_none() { if aggregator.error.is_none()
if let Some(response) = delta.data { && let Some(response) = delta.data
// For embeddings, we typically expect a single complete response {
// or we accumulate data from multiple responses // For embeddings, we typically expect a single complete response
match &mut aggregator.response { // or we accumulate data from multiple responses
Some(existing) => { match &mut aggregator.response {
// Merge embedding data if we have multiple responses Some(existing) => {
existing.inner.data.extend(response.inner.data); // Merge embedding data if we have multiple responses
existing.inner.data.extend(response.inner.data);
// Update usage statistics
existing.inner.usage.prompt_tokens += // Update usage statistics
response.inner.usage.prompt_tokens; existing.inner.usage.prompt_tokens +=
existing.inner.usage.total_tokens += response.inner.usage.prompt_tokens;
response.inner.usage.total_tokens; existing.inner.usage.total_tokens += response.inner.usage.total_tokens;
} }
None => { None => {
aggregator.response = Some(response); aggregator.response = Some(response);
}
} }
} }
} }
......
...@@ -93,30 +93,30 @@ pub const MAX_PROMPT_TOKEN_ID: u32 = 50256; ...@@ -93,30 +93,30 @@ pub const MAX_PROMPT_TOKEN_ID: u32 = 50256;
/// Validates the temperature parameter /// Validates the temperature parameter
pub fn validate_temperature(temperature: Option<f32>) -> Result<(), anyhow::Error> { pub fn validate_temperature(temperature: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(temp) = temperature { if let Some(temp) = temperature
if !(MIN_TEMPERATURE..=MAX_TEMPERATURE).contains(&temp) { && !(MIN_TEMPERATURE..=MAX_TEMPERATURE).contains(&temp)
anyhow::bail!( {
"Temperature must be between {} and {}, got {}", anyhow::bail!(
MIN_TEMPERATURE, "Temperature must be between {} and {}, got {}",
MAX_TEMPERATURE, MIN_TEMPERATURE,
temp MAX_TEMPERATURE,
); temp
} );
} }
Ok(()) Ok(())
} }
/// Validates the top_p parameter /// Validates the top_p parameter
pub fn validate_top_p(top_p: Option<f32>) -> Result<(), anyhow::Error> { pub fn validate_top_p(top_p: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(p) = top_p { if let Some(p) = top_p
if !(MIN_TOP_P..=MAX_TOP_P).contains(&p) { && !(MIN_TOP_P..=MAX_TOP_P).contains(&p)
anyhow::bail!( {
"Top_p must be between {} and {}, got {}", anyhow::bail!(
MIN_TOP_P, "Top_p must be between {} and {}, got {}",
MAX_TOP_P, MIN_TOP_P,
p MAX_TOP_P,
); p
} );
} }
Ok(()) Ok(())
} }
...@@ -136,30 +136,30 @@ pub fn validate_temperature_top_p_exclusion( ...@@ -136,30 +136,30 @@ pub fn validate_temperature_top_p_exclusion(
/// Validates frequency penalty parameter /// Validates frequency penalty parameter
pub fn validate_frequency_penalty(frequency_penalty: Option<f32>) -> Result<(), anyhow::Error> { pub fn validate_frequency_penalty(frequency_penalty: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(penalty) = frequency_penalty { if let Some(penalty) = frequency_penalty
if !(MIN_FREQUENCY_PENALTY..=MAX_FREQUENCY_PENALTY).contains(&penalty) { && !(MIN_FREQUENCY_PENALTY..=MAX_FREQUENCY_PENALTY).contains(&penalty)
anyhow::bail!( {
"Frequency penalty must be between {} and {}, got {}", anyhow::bail!(
MIN_FREQUENCY_PENALTY, "Frequency penalty must be between {} and {}, got {}",
MAX_FREQUENCY_PENALTY, MIN_FREQUENCY_PENALTY,
penalty MAX_FREQUENCY_PENALTY,
); penalty
} );
} }
Ok(()) Ok(())
} }
/// Validates presence penalty parameter /// Validates presence penalty parameter
pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), anyhow::Error> { pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(penalty) = presence_penalty { if let Some(penalty) = presence_penalty
if !(MIN_PRESENCE_PENALTY..=MAX_PRESENCE_PENALTY).contains(&penalty) { && !(MIN_PRESENCE_PENALTY..=MAX_PRESENCE_PENALTY).contains(&penalty)
anyhow::bail!( {
"Presence penalty must be between {} and {}, got {}", anyhow::bail!(
MIN_PRESENCE_PENALTY, "Presence penalty must be between {} and {}, got {}",
MAX_PRESENCE_PENALTY, MIN_PRESENCE_PENALTY,
penalty MAX_PRESENCE_PENALTY,
); penalty
} );
} }
Ok(()) Ok(())
} }
...@@ -197,10 +197,10 @@ pub fn validate_logit_bias( ...@@ -197,10 +197,10 @@ pub fn validate_logit_bias(
/// Validates n parameter (number of choices) /// Validates n parameter (number of choices)
pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> { pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = n { if let Some(value) = n
if !(MIN_N..=MAX_N).contains(&value) { && !(MIN_N..=MAX_N).contains(&value)
anyhow::bail!("n must be between {} and {}, got {}", MIN_N, MAX_N, value); {
} anyhow::bail!("n must be between {} and {}, got {}", MIN_N, MAX_N, value);
} }
Ok(()) Ok(())
} }
...@@ -215,10 +215,10 @@ pub fn validate_model(model: &str) -> Result<(), anyhow::Error> { ...@@ -215,10 +215,10 @@ pub fn validate_model(model: &str) -> Result<(), anyhow::Error> {
/// Validates user parameter /// Validates user parameter
pub fn validate_user(user: Option<&str>) -> Result<(), anyhow::Error> { pub fn validate_user(user: Option<&str>) -> Result<(), anyhow::Error> {
if let Some(user_id) = user { if let Some(user_id) = user
if user_id.trim().is_empty() { && user_id.trim().is_empty()
anyhow::bail!("User ID cannot be empty"); {
} anyhow::bail!("User ID cannot be empty");
} }
Ok(()) Ok(())
} }
...@@ -270,14 +270,14 @@ pub fn validate_messages( ...@@ -270,14 +270,14 @@ pub fn validate_messages(
/// Validates top_logprobs parameter /// Validates top_logprobs parameter
pub fn validate_top_logprobs(top_logprobs: Option<u8>) -> Result<(), anyhow::Error> { pub fn validate_top_logprobs(top_logprobs: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = top_logprobs { if let Some(value) = top_logprobs
if !(0..=20).contains(&value) { && !(0..=20).contains(&value)
anyhow::bail!( {
"Top_logprobs must be between 0 and {}, got {}", anyhow::bail!(
MAX_TOP_LOGPROBS, "Top_logprobs must be between 0 and {}, got {}",
value MAX_TOP_LOGPROBS,
); value
} );
} }
Ok(()) Ok(())
} }
...@@ -340,14 +340,14 @@ pub fn validate_metadata(metadata: &Option<serde_json::Value>) -> Result<(), any ...@@ -340,14 +340,14 @@ pub fn validate_metadata(metadata: &Option<serde_json::Value>) -> Result<(), any
); );
} }
if let Some(value_str) = value.as_str() { if let Some(value_str) = value.as_str()
if value_str.len() > MAX_METADATA_VALUE_LENGTH { && value_str.len() > MAX_METADATA_VALUE_LENGTH
anyhow::bail!( {
"Metadata value for key '{}' exceeds {} character limit", anyhow::bail!(
key, "Metadata value for key '{}' exceeds {} character limit",
MAX_METADATA_VALUE_LENGTH key,
); MAX_METADATA_VALUE_LENGTH
} );
} }
} }
} }
...@@ -438,14 +438,14 @@ pub fn validate_prompt(prompt: &dynamo_async_openai::types::Prompt) -> Result<() ...@@ -438,14 +438,14 @@ pub fn validate_prompt(prompt: &dynamo_async_openai::types::Prompt) -> Result<()
/// Validates logprobs parameter (for completion requests) /// Validates logprobs parameter (for completion requests)
pub fn validate_logprobs(logprobs: Option<u8>) -> Result<(), anyhow::Error> { pub fn validate_logprobs(logprobs: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = logprobs { if let Some(value) = logprobs
if !(MIN_LOGPROBS..=MAX_LOGPROBS).contains(&value) { && !(MIN_LOGPROBS..=MAX_LOGPROBS).contains(&value)
anyhow::bail!( {
"Logprobs must be between 0 and {}, got {}", anyhow::bail!(
MAX_LOGPROBS, "Logprobs must be between 0 and {}, got {}",
value MAX_LOGPROBS,
); value
} );
} }
Ok(()) Ok(())
} }
...@@ -461,14 +461,14 @@ pub fn validate_best_of(best_of: Option<u8>, n: Option<u8>) -> Result<(), anyhow ...@@ -461,14 +461,14 @@ pub fn validate_best_of(best_of: Option<u8>, n: Option<u8>) -> Result<(), anyhow
); );
} }
if let Some(n_value) = n { if let Some(n_value) = n
if best_of_value < n_value { && best_of_value < n_value
anyhow::bail!( {
"Best_of must be greater than or equal to n, got best_of={} and n={}", anyhow::bail!(
best_of_value, "Best_of must be greater than or equal to n, got best_of={} and n={}",
n_value best_of_value,
); n_value
} );
} }
} }
Ok(()) Ok(())
...@@ -487,10 +487,10 @@ pub fn validate_suffix(suffix: Option<&str>) -> Result<(), anyhow::Error> { ...@@ -487,10 +487,10 @@ pub fn validate_suffix(suffix: Option<&str>) -> Result<(), anyhow::Error> {
/// Validates max_tokens parameter /// Validates max_tokens parameter
pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), anyhow::Error> { pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), anyhow::Error> {
if let Some(tokens) = max_tokens { if let Some(tokens) = max_tokens
if tokens == 0 { && tokens == 0
anyhow::bail!("Max tokens must be greater than 0, got {}", tokens); {
} anyhow::bail!("Max tokens must be greater than 0, got {}", tokens);
} }
Ok(()) Ok(())
} }
...@@ -499,13 +499,13 @@ pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), anyhow::Error> ...@@ -499,13 +499,13 @@ pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), anyhow::Error>
pub fn validate_max_completion_tokens( pub fn validate_max_completion_tokens(
max_completion_tokens: Option<u32>, max_completion_tokens: Option<u32>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
if let Some(tokens) = max_completion_tokens { if let Some(tokens) = max_completion_tokens
if tokens == 0 { && tokens == 0
anyhow::bail!( {
"Max completion tokens must be greater than 0, got {}", anyhow::bail!(
tokens "Max completion tokens must be greater than 0, got {}",
); tokens
} );
} }
Ok(()) Ok(())
} }
......
...@@ -8,7 +8,7 @@ use std::sync::Arc; ...@@ -8,7 +8,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::fs::{self, File, OpenOptions}; use tokio::fs::{self, File, OpenOptions};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{Mutex, mpsc};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
/// Record entry that will be serialized to JSONL /// Record entry that will be serialized to JSONL
...@@ -70,10 +70,10 @@ where ...@@ -70,10 +70,10 @@ where
let first_event_time_clone = first_event_time.clone(); let first_event_time_clone = first_event_time.clone();
// Ensure the directory exists // Ensure the directory exists
if let Some(parent) = output_path.as_ref().parent() { if let Some(parent) = output_path.as_ref().parent()
if !parent.exists() { && !parent.exists()
fs::create_dir_all(parent).await?; {
} fs::create_dir_all(parent).await?;
} }
// Create the file for writing // Create the file for writing
...@@ -102,16 +102,16 @@ where ...@@ -102,16 +102,16 @@ where
loop { loop {
// Check time limit if set // Check time limit if set
if let Some(deadline) = max_time_deadline { if let Some(deadline) = max_time_deadline
if Instant::now() >= deadline { && Instant::now() >= deadline
tracing::info!("Recorder reached max time limit, shutting down"); {
// Flush and cancel tracing::info!("Recorder reached max time limit, shutting down");
if let Err(e) = writer.flush().await { // Flush and cancel
tracing::error!("Failed to flush on time limit shutdown: {}", e); if let Err(e) = writer.flush().await {
} tracing::error!("Failed to flush on time limit shutdown: {}", e);
cancel_clone.cancel();
return;
} }
cancel_clone.cancel();
return;
} }
tokio::select! { tokio::select! {
...@@ -170,8 +170,8 @@ where ...@@ -170,8 +170,8 @@ where
line_count += 1; line_count += 1;
// Check if we need to rotate to a new file // Check if we need to rotate to a new file
if let Some(max_lines) = max_lines_per_file { if let Some(max_lines) = max_lines_per_file
if line_count >= max_lines { && line_count >= max_lines {
// Flush the current file // Flush the current file
if let Err(e) = writer.flush().await { if let Err(e) = writer.flush().await {
tracing::error!("Failed to flush file before rotation: {}", e); tracing::error!("Failed to flush file before rotation: {}", e);
...@@ -200,15 +200,14 @@ where ...@@ -200,15 +200,14 @@ where
} }
} }
} }
}
// Update event count // Update event count
let mut count = event_count_clone.lock().await; let mut count = event_count_clone.lock().await;
*count += 1; *count += 1;
// Check if we've reached the maximum count // Check if we've reached the maximum count
if let Some(max) = max_count { if let Some(max) = max_count
if *count >= max { && *count >= max {
tracing::info!("Recorder reached max event count ({}), shutting down", max); tracing::info!("Recorder reached max event count ({}), shutting down", max);
// Flush buffer before shutting down // Flush buffer before shutting down
if let Err(e) = writer.flush().await { if let Err(e) = writer.flush().await {
...@@ -219,7 +218,6 @@ where ...@@ -219,7 +218,6 @@ where
cancel_clone.cancel(); cancel_clone.cancel();
return; return;
} }
}
} }
} }
} }
...@@ -307,19 +305,19 @@ where ...@@ -307,19 +305,19 @@ where
// Read and send events line by line // Read and send events line by line
while let Some(line) = lines.next_line().await? { while let Some(line) = lines.next_line().await? {
// Check if we've reached the maximum count // Check if we've reached the maximum count
if let Some(max) = max_count { if let Some(max) = max_count
if count >= max { && count >= max
tracing::info!("Reached maximum event count ({}), stopping", max); {
break; tracing::info!("Reached maximum event count ({}), stopping", max);
} break;
} }
// Check if we've exceeded the time limit // Check if we've exceeded the time limit
if let Some(end_time) = deadline { if let Some(end_time) = deadline
if Instant::now() >= end_time { && Instant::now() >= end_time
tracing::info!("Reached maximum time limit, stopping"); {
break; tracing::info!("Reached maximum time limit, stopping");
} break;
} }
line_number += 1; line_number += 1;
...@@ -346,12 +344,12 @@ where ...@@ -346,12 +344,12 @@ where
let event = record.event; let event = record.event;
// Handle timing if needed // Handle timing if needed
if timed && prev_timestamp.is_some() { if timed
let prev = prev_timestamp.unwrap(); && let Some(prev) = prev_timestamp
if timestamp > prev { && timestamp > prev
let wait_time = timestamp - prev; {
tokio::time::sleep(Duration::from_millis(wait_time)).await; let wait_time = timestamp - prev;
} tokio::time::sleep(Duration::from_millis(wait_time)).await;
} }
// Send the event // Send the event
...@@ -612,7 +610,10 @@ mod tests { ...@@ -612,7 +610,10 @@ mod tests {
// Should have MAX_LINES_PER_FILE lines in each file (except maybe the last one) // Should have MAX_LINES_PER_FILE lines in each file (except maybe the last one)
if i < found_files.len() - 1 { if i < found_files.len() - 1 {
assert_eq!(line_count, MAX_LINES_PER_FILE, "Each file except possibly the last should have exactly MAX_LINES_PER_FILE lines"); assert_eq!(
line_count, MAX_LINES_PER_FILE,
"Each file except possibly the last should have exactly MAX_LINES_PER_FILE lines"
);
} }
total_lines += line_count; total_lines += line_count;
...@@ -631,19 +632,19 @@ mod tests { ...@@ -631,19 +632,19 @@ mod tests {
line_number += 1; line_number += 1;
let entry: RecordEntry<TestEvent> = serde_json::from_str(&line).unwrap(); let entry: RecordEntry<TestEvent> = serde_json::from_str(&line).unwrap();
if let Some(prev) = prev_timestamp { if let Some(prev) = prev_timestamp
if entry.timestamp < prev { && entry.timestamp < prev
unsorted_count += 1; {
if unsorted_count <= 5 { unsorted_count += 1;
// Only log first 5 violations to avoid spam if unsorted_count <= 5 {
println!( // Only log first 5 violations to avoid spam
"Timestamp order violation in file {} at line {}: {} < {}", println!(
file_path.display(), "Timestamp order violation in file {} at line {}: {} < {}",
line_number, file_path.display(),
entry.timestamp, line_number,
prev entry.timestamp,
); prev
} );
} }
} }
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
use tokenizers::tokenizer::Tokenizer as HfTokenizer; use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::{ use super::{
traits::{Decoder, Encoder, Tokenizer},
Encoding, Error, Result, TokenIdType, Encoding, Error, Result, TokenIdType,
traits::{Decoder, Encoder, Tokenizer},
}; };
pub struct HuggingFaceTokenizer { pub struct HuggingFaceTokenizer {
......
...@@ -603,7 +603,11 @@ impl TokenBlockSequence { ...@@ -603,7 +603,11 @@ impl TokenBlockSequence {
Some(range) => { Some(range) => {
// Since we only added one token, the range can only be empty or have one element. // Since we only added one token, the range can only be empty or have one element.
// If it's not empty, it must be `n..(n+1)`. // If it's not empty, it must be `n..(n+1)`.
assert_eq!(range.len(), 1, "Appending a single token completed more than one block, which should be impossible."); assert_eq!(
range.len(),
1,
"Appending a single token completed more than one block, which should be impossible."
);
Ok(Some(range.start)) Ok(Some(range.start))
} }
} }
...@@ -1108,7 +1112,7 @@ mod tests { ...@@ -1108,7 +1112,7 @@ mod tests {
let tokens2 = Tokens::from(vec![5, 6, 7, 8]); let tokens2 = Tokens::from(vec![5, 6, 7, 8]);
let chunk2 = TokenBlockChunk::new(tokens2.clone(), salt); let chunk2 = TokenBlockChunk::new(tokens2.clone(), salt);
let block2 = TokenBlock::from_chunk(chunk2, block1.parent_sequence_hash()); // Incorrect parent let block2 = TokenBlock::from_chunk(chunk2, block1.parent_sequence_hash()); // Incorrect parent
// Sequence hash should differ if parent is wrong // Sequence hash should differ if parent is wrong
assert_ne!(block2.sequence_hash(), SEQ_HASH_5_8); assert_ne!(block2.sequence_hash(), SEQ_HASH_5_8);
let chunk2_correct = TokenBlockChunk::new(tokens2.clone(), salt); let chunk2_correct = TokenBlockChunk::new(tokens2.clone(), salt);
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
// limitations under the License. // limitations under the License.
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
codec::{create_message_stream, Message, SseCodecError}, ContentProvider, DataStream,
codec::{Message, SseCodecError, create_message_stream},
openai::{ openai::{
chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse},
completions::NvCreateCompletionResponse,
ParsingOptions, ParsingOptions,
chat_completions::{NvCreateChatCompletionResponse, aggregator::ChatCompletionAggregator},
completions::NvCreateCompletionResponse,
}, },
ContentProvider, DataStream,
}; };
use futures::StreamExt; use futures::StreamExt;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment