Unverified commit bce74588, authored by Graham King and committed by GitHub
Browse files

chore: Rust to 1.89 and edition 2024 (#2659)

parent 268d017e
......@@ -102,18 +102,18 @@ impl KvManager {
store: bool,
parent_hash: Option<u64>,
) {
if let Some(ref tx) = self.move_block_response_tx {
if !blocks.is_empty() {
if reverse {
blocks.reverse();
}
let response = if store {
MoveBlockResponse::Store(blocks, parent_hash)
} else {
MoveBlockResponse::Remove(blocks)
};
tx.send(response).unwrap();
if let Some(ref tx) = self.move_block_response_tx
&& !blocks.is_empty()
{
if reverse {
blocks.reverse();
}
let response = if store {
MoveBlockResponse::Store(blocks, parent_hash)
} else {
MoveBlockResponse::Remove(blocks)
};
tx.send(response).unwrap();
}
}
......@@ -159,10 +159,10 @@ impl KvManager {
// Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1);
self.all_blocks.insert(hash.clone());
if self.move_block_response_tx.is_some() {
if let UniqueBlock::FullBlock(stored_full_block) = hash {
blocks_stored.push(*stored_full_block);
}
if self.move_block_response_tx.is_some()
&& let UniqueBlock::FullBlock(stored_full_block) = hash
{
blocks_stored.push(*stored_full_block);
}
}
......@@ -184,10 +184,10 @@ impl KvManager {
assert!(self.all_blocks.remove(hash));
// Track blocks for batch sending
if self.move_block_response_tx.is_some() {
if let UniqueBlock::FullBlock(destroyed_full_block) = hash {
blocks_destroyed.push(*destroyed_full_block);
}
if self.move_block_response_tx.is_some()
&& let UniqueBlock::FullBlock(destroyed_full_block) = hash
{
blocks_destroyed.push(*destroyed_full_block);
}
}
......
......@@ -160,58 +160,58 @@ impl MockEngineArgs {
}
// Apply each extra argument to the builder
if let Some(value) = extra_args.get("num_gpu_blocks") {
if let Some(num) = value.as_u64() {
builder = builder.num_gpu_blocks(num as usize);
}
if let Some(value) = extra_args.get("num_gpu_blocks")
&& let Some(num) = value.as_u64()
{
builder = builder.num_gpu_blocks(num as usize);
}
if let Some(value) = extra_args.get("block_size") {
if let Some(num) = value.as_u64() {
builder = builder.block_size(num as usize);
}
if let Some(value) = extra_args.get("block_size")
&& let Some(num) = value.as_u64()
{
builder = builder.block_size(num as usize);
}
if let Some(value) = extra_args.get("max_num_seqs") {
if let Some(num) = value.as_u64() {
builder = builder.max_num_seqs(Some(num as usize));
}
if let Some(value) = extra_args.get("max_num_seqs")
&& let Some(num) = value.as_u64()
{
builder = builder.max_num_seqs(Some(num as usize));
}
if let Some(value) = extra_args.get("max_num_batched_tokens") {
if let Some(num) = value.as_u64() {
builder = builder.max_num_batched_tokens(Some(num as usize));
}
if let Some(value) = extra_args.get("max_num_batched_tokens")
&& let Some(num) = value.as_u64()
{
builder = builder.max_num_batched_tokens(Some(num as usize));
}
if let Some(value) = extra_args.get("enable_prefix_caching") {
if let Some(enabled) = value.as_bool() {
builder = builder.enable_prefix_caching(enabled);
}
if let Some(value) = extra_args.get("enable_prefix_caching")
&& let Some(enabled) = value.as_bool()
{
builder = builder.enable_prefix_caching(enabled);
}
if let Some(value) = extra_args.get("enable_chunked_prefill") {
if let Some(enabled) = value.as_bool() {
builder = builder.enable_chunked_prefill(enabled);
}
if let Some(value) = extra_args.get("enable_chunked_prefill")
&& let Some(enabled) = value.as_bool()
{
builder = builder.enable_chunked_prefill(enabled);
}
if let Some(value) = extra_args.get("watermark") {
if let Some(num) = value.as_f64() {
builder = builder.watermark(num);
}
if let Some(value) = extra_args.get("watermark")
&& let Some(num) = value.as_f64()
{
builder = builder.watermark(num);
}
if let Some(value) = extra_args.get("speedup_ratio") {
if let Some(num) = value.as_f64() {
builder = builder.speedup_ratio(num);
}
if let Some(value) = extra_args.get("speedup_ratio")
&& let Some(num) = value.as_f64()
{
builder = builder.speedup_ratio(num);
}
if let Some(value) = extra_args.get("dp_size") {
if let Some(num) = value.as_u64() {
builder = builder.dp_size(num as u32);
}
if let Some(value) = extra_args.get("dp_size")
&& let Some(num) = value.as_u64()
{
builder = builder.dp_size(num as u32);
}
// Build the MockEngineArgs with either defaults or overridden values
......
......@@ -31,15 +31,15 @@
use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats};
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager;
use crate::mocker::protocols::{block_response_to_kv_event, MoveBlock, OutputSignal, PrefillCost};
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
use crate::mocker::protocols::{MoveBlock, OutputSignal, PrefillCost, block_response_to_kv_event};
use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::BlockHash;
use crate::tokens::blocks::UniqueBlock;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{Mutex, mpsc};
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
......@@ -414,9 +414,7 @@ impl Scheduler {
}
// Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) =
(&kv_events_tx, &mut block_resp_rx)
{
if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ =
relay_tx.send(block_response_to_kv_event(event, &block_hashes));
......@@ -465,9 +463,7 @@ impl Scheduler {
}
// Drain KV events and forward to relay after decode signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) =
(&kv_events_tx, &mut block_resp_rx)
{
if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx
.send(block_response_to_kv_event(event, &sequence.block_hashes()));
......@@ -663,7 +659,9 @@ fn process_signals(
// Check we have a Use signal with blocks
let MoveBlock::Use(blocks) = signal else {
panic!("Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}");
panic!(
"Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
);
};
// Verify the signal contains exactly one block
......@@ -708,7 +706,7 @@ mod tests {
#[case] enable_prefix_caching: bool,
#[case] enable_chunked_prefill: bool,
) {
std::env::set_var("RUST_LOG", "debug");
unsafe { std::env::set_var("RUST_LOG", "debug") };
let kv_capacity: usize = 500;
let block_size: usize = 64;
......
......@@ -568,7 +568,7 @@ mod tests {
type TestTokenAlternative = (&'static str, f32);
type TestTokenData = (&'static str, f32, Vec<TestTokenAlternative>);
type TestTokenDataVec = Vec<TestTokenData>;
use crate::perf::{record_stream_with_context, RecordingMode, TimestampedResponse};
use crate::perf::{RecordingMode, TimestampedResponse, record_stream_with_context};
use crate::protocols::codec::create_message_stream;
use crate::protocols::convert_sse_stream;
use approx::assert_abs_diff_eq;
......
......@@ -28,21 +28,21 @@ use crate::tokenizers::Encoding;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn,
AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
};
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};
use crate::protocols::{
common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
openai::{
DeltaGeneratorExt,
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
nvext::NvExtProvider,
DeltaGeneratorExt,
},
};
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};
use crate::tokenizers::{HuggingFaceTokenizer, traits::Tokenizer};
use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
......@@ -487,11 +487,7 @@ impl
&self,
request: SingleIn<NvCreateChatCompletionRequest>,
next: Arc<
dyn AsyncEngine<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
Error,
>,
dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
// unpack the request
......@@ -545,11 +541,7 @@ impl
&self,
request: SingleIn<NvCreateCompletionRequest>,
next: Arc<
dyn AsyncEngine<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
Error,
>,
dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
>,
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
// unpack the request
......@@ -603,10 +595,10 @@ impl
request: SingleIn<NvCreateEmbeddingRequest>,
next: Arc<
dyn AsyncEngine<
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
Error,
>,
SingleIn<PreprocessedEmbeddingRequest>,
ManyOut<Annotated<EmbeddingsEngineOutput>>,
Error,
>,
>,
) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
// Unpack request
......
......@@ -15,7 +15,7 @@
use std::sync::Arc;
use super::tokcfg::{raise_exception, strftime_now, tojson, ChatTemplate};
use super::tokcfg::{ChatTemplate, raise_exception, strftime_now, tojson};
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
use either::Either;
use minijinja::Environment;
......@@ -60,7 +60,9 @@ impl HfTokenizerConfigJsonFormatter {
match &chat_template.0 {
Either::Left(x) => {
if x.contains("add_generation_prompt") {
tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt.");
tracing::debug!(
"Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
);
supports_add_generation_prompt = Some(true);
}
env.add_template_owned("default", x.to_string())?;
......@@ -72,11 +74,15 @@ impl HfTokenizerConfigJsonFormatter {
if v.contains("add_generation_prompt") {
match supports_add_generation_prompt {
Some(true) | None => {
tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt.");
tracing::debug!(
"Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt."
);
supports_add_generation_prompt = Some(true);
}
Some(false) => {
tracing::warn!("Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt.");
tracing::warn!(
"Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt."
);
}
}
} else {
......@@ -86,7 +92,9 @@ impl HfTokenizerConfigJsonFormatter {
}
}
if env.templates().count() == 0 {
anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
anyhow::bail!(
"Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools."
);
}
}
}
......
......@@ -21,7 +21,7 @@ use chrono::{DateTime, Local};
use either::Either;
use ggus::{GGufMetaKV, GGufReader};
use memmap2::Mmap;
use minijinja::{value::Kwargs, Error, ErrorKind, Value};
use minijinja::{Error, ErrorKind, Value, value::Kwargs};
use serde::{Deserialize, Serialize};
#[allow(dead_code)]
......
......@@ -98,14 +98,14 @@ where
fn try_from(value: Message) -> Result<Annotated<T>, Self::Error> {
// determine if the message had an error
if let Some(event) = value.event.as_ref() {
if event == "error" {
let message = match &value.comments {
Some(comments) => comments.join("\n"),
None => "`event: error` detected, but no error message found".to_string(),
};
return Err(message);
}
if let Some(event) = value.event.as_ref()
&& event == "error"
{
let message = match &value.comments {
Some(comments) => comments.join("\n"),
None => "`event: error` detected, but no error message found".to_string(),
};
return Err(message);
}
// try to deserialize the data to T
......
......@@ -15,8 +15,8 @@
use serde::{Deserialize, Serialize};
pub use super::preprocessor::PreprocessedRequest;
pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::maybe_error::MaybeError;
......
......@@ -5,8 +5,8 @@ use anyhow::Result;
use serde::{Deserialize, Serialize};
use super::{
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider,
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
};
use crate::protocols::openai::common_ext::CommonExtProvider;
......@@ -20,7 +20,7 @@ pub mod responses;
pub mod validate;
use validate::{
validate_range, FREQUENCY_PENALTY_RANGE, PRESENCE_PENALTY_RANGE, TEMPERATURE_RANGE, TOP_P_RANGE,
FREQUENCY_PENALTY_RANGE, PRESENCE_PENALTY_RANGE, TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
};
#[derive(Serialize, Deserialize, Debug)]
......@@ -147,10 +147,10 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
let min_tokens = self.get_min_tokens();
let stop = self.get_stop();
if let Some(stop) = &stop {
if stop.len() > 4 {
anyhow::bail!("stop conditions must be less than 4")
}
if let Some(stop) = &stop
&& stop.len() > 4
{
anyhow::bail!("stop conditions must be less than 4")
}
// Use the trait method to get ignore_eos, which handles precedence
......
......@@ -20,11 +20,11 @@ use validator::Validate;
use crate::engines::ValidateRequest;
use super::{
OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt,
nvext::NvExtProvider,
validate, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
validate,
};
pub mod aggregator;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::{Stream, StreamExt};
use std::collections::HashMap;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{
Annotated,
codec::{Message, SseCodecError},
convert_sse_stream,
openai::ParsingOptions,
Annotated,
};
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
......@@ -177,27 +165,26 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() {
if let Ok(tool_calls) = try_tool_call_parse_aggregate(
if choice.tool_calls.is_none()
&& let Ok(tool_calls) = try_tool_call_parse_aggregate(
&choice.text,
parsing_options.tool_call_parser.as_deref(),
) {
if tool_calls.is_empty() {
continue;
}
for tool_call in &tool_calls {
tracing::debug!(
tool_call_id = %tool_call.id,
function_name = %tool_call.function.name,
arguments = %tool_call.function.arguments,
"Parsed structured tool call from aggregated content"
);
}
choice.tool_calls = Some(tool_calls);
choice.text.clear();
choice.finish_reason =
Some(dynamo_async_openai::types::FinishReason::ToolCalls);
)
{
if tool_calls.is_empty() {
continue;
}
for tool_call in &tool_calls {
tracing::debug!(
tool_call_id = %tool_call.id,
function_name = %tool_call.function.name,
arguments = %tool_call.function.arguments,
"Parsed structured tool call from aggregated content"
);
}
choice.tool_calls = Some(tool_calls);
choice.text.clear();
choice.finish_reason = Some(dynamo_async_openai::types::FinishReason::ToolCalls);
}
}
......
......@@ -21,11 +21,12 @@ use validator::Validate;
use crate::engines::ValidateRequest;
use super::{
ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
common_ext::{CommonExt, CommonExtProvider},
nvext::{NvExt, NvExtProvider},
validate, ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
validate,
};
mod aggregator;
......@@ -87,12 +88,11 @@ impl NvExtProvider for NvCreateCompletionRequest {
}
fn raw_prompt(&self) -> Option<String> {
if let Some(nvext) = self.nvext.as_ref() {
if let Some(use_raw_prompt) = nvext.use_raw_prompt {
if use_raw_prompt {
return Some(prompt_to_string(&self.inner.prompt));
}
}
if let Some(nvext) = self.nvext.as_ref()
&& let Some(use_raw_prompt) = nvext.use_raw_prompt
&& use_raw_prompt
{
return Some(prompt_to_string(&self.inner.prompt));
}
None
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
......@@ -20,11 +8,11 @@ use futures::{Stream, StreamExt};
use super::NvCreateCompletionResponse;
use crate::protocols::{
Annotated, DataStream,
codec::{Message, SseCodecError},
common::FinishReason,
convert_sse_stream,
openai::ParsingOptions,
Annotated, DataStream,
};
/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
......
......@@ -15,8 +15,9 @@
use super::NvCreateEmbeddingResponse;
use crate::protocols::{
Annotated,
codec::{Message, SseCodecError},
convert_sse_stream, Annotated,
convert_sse_stream,
};
use dynamo_runtime::engine::DataStream;
......@@ -71,24 +72,23 @@ impl DeltaAggregator {
}
};
if aggregator.error.is_none() {
if let Some(response) = delta.data {
// For embeddings, we typically expect a single complete response
// or we accumulate data from multiple responses
match &mut aggregator.response {
Some(existing) => {
// Merge embedding data if we have multiple responses
existing.inner.data.extend(response.inner.data);
// Update usage statistics
existing.inner.usage.prompt_tokens +=
response.inner.usage.prompt_tokens;
existing.inner.usage.total_tokens +=
response.inner.usage.total_tokens;
}
None => {
aggregator.response = Some(response);
}
if aggregator.error.is_none()
&& let Some(response) = delta.data
{
// For embeddings, we typically expect a single complete response
// or we accumulate data from multiple responses
match &mut aggregator.response {
Some(existing) => {
// Merge embedding data if we have multiple responses
existing.inner.data.extend(response.inner.data);
// Update usage statistics
existing.inner.usage.prompt_tokens +=
response.inner.usage.prompt_tokens;
existing.inner.usage.total_tokens += response.inner.usage.total_tokens;
}
None => {
aggregator.response = Some(response);
}
}
}
......
......@@ -93,30 +93,30 @@ pub const MAX_PROMPT_TOKEN_ID: u32 = 50256;
/// Validates the temperature parameter
pub fn validate_temperature(temperature: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(temp) = temperature {
if !(MIN_TEMPERATURE..=MAX_TEMPERATURE).contains(&temp) {
anyhow::bail!(
"Temperature must be between {} and {}, got {}",
MIN_TEMPERATURE,
MAX_TEMPERATURE,
temp
);
}
if let Some(temp) = temperature
&& !(MIN_TEMPERATURE..=MAX_TEMPERATURE).contains(&temp)
{
anyhow::bail!(
"Temperature must be between {} and {}, got {}",
MIN_TEMPERATURE,
MAX_TEMPERATURE,
temp
);
}
Ok(())
}
/// Validates the top_p parameter
pub fn validate_top_p(top_p: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(p) = top_p {
if !(MIN_TOP_P..=MAX_TOP_P).contains(&p) {
anyhow::bail!(
"Top_p must be between {} and {}, got {}",
MIN_TOP_P,
MAX_TOP_P,
p
);
}
if let Some(p) = top_p
&& !(MIN_TOP_P..=MAX_TOP_P).contains(&p)
{
anyhow::bail!(
"Top_p must be between {} and {}, got {}",
MIN_TOP_P,
MAX_TOP_P,
p
);
}
Ok(())
}
......@@ -136,30 +136,30 @@ pub fn validate_temperature_top_p_exclusion(
/// Validates frequency penalty parameter
pub fn validate_frequency_penalty(frequency_penalty: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(penalty) = frequency_penalty {
if !(MIN_FREQUENCY_PENALTY..=MAX_FREQUENCY_PENALTY).contains(&penalty) {
anyhow::bail!(
"Frequency penalty must be between {} and {}, got {}",
MIN_FREQUENCY_PENALTY,
MAX_FREQUENCY_PENALTY,
penalty
);
}
if let Some(penalty) = frequency_penalty
&& !(MIN_FREQUENCY_PENALTY..=MAX_FREQUENCY_PENALTY).contains(&penalty)
{
anyhow::bail!(
"Frequency penalty must be between {} and {}, got {}",
MIN_FREQUENCY_PENALTY,
MAX_FREQUENCY_PENALTY,
penalty
);
}
Ok(())
}
/// Validates presence penalty parameter
pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(penalty) = presence_penalty {
if !(MIN_PRESENCE_PENALTY..=MAX_PRESENCE_PENALTY).contains(&penalty) {
anyhow::bail!(
"Presence penalty must be between {} and {}, got {}",
MIN_PRESENCE_PENALTY,
MAX_PRESENCE_PENALTY,
penalty
);
}
if let Some(penalty) = presence_penalty
&& !(MIN_PRESENCE_PENALTY..=MAX_PRESENCE_PENALTY).contains(&penalty)
{
anyhow::bail!(
"Presence penalty must be between {} and {}, got {}",
MIN_PRESENCE_PENALTY,
MAX_PRESENCE_PENALTY,
penalty
);
}
Ok(())
}
......@@ -197,10 +197,10 @@ pub fn validate_logit_bias(
/// Validates n parameter (number of choices)
pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = n {
if !(MIN_N..=MAX_N).contains(&value) {
anyhow::bail!("n must be between {} and {}, got {}", MIN_N, MAX_N, value);
}
if let Some(value) = n
&& !(MIN_N..=MAX_N).contains(&value)
{
anyhow::bail!("n must be between {} and {}, got {}", MIN_N, MAX_N, value);
}
Ok(())
}
......@@ -215,10 +215,10 @@ pub fn validate_model(model: &str) -> Result<(), anyhow::Error> {
/// Validates user parameter
pub fn validate_user(user: Option<&str>) -> Result<(), anyhow::Error> {
if let Some(user_id) = user {
if user_id.trim().is_empty() {
anyhow::bail!("User ID cannot be empty");
}
if let Some(user_id) = user
&& user_id.trim().is_empty()
{
anyhow::bail!("User ID cannot be empty");
}
Ok(())
}
......@@ -270,14 +270,14 @@ pub fn validate_messages(
/// Validates top_logprobs parameter
pub fn validate_top_logprobs(top_logprobs: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = top_logprobs {
if !(0..=20).contains(&value) {
anyhow::bail!(
"Top_logprobs must be between 0 and {}, got {}",
MAX_TOP_LOGPROBS,
value
);
}
if let Some(value) = top_logprobs
&& !(0..=20).contains(&value)
{
anyhow::bail!(
"Top_logprobs must be between 0 and {}, got {}",
MAX_TOP_LOGPROBS,
value
);
}
Ok(())
}
......@@ -340,14 +340,14 @@ pub fn validate_metadata(metadata: &Option<serde_json::Value>) -> Result<(), any
);
}
if let Some(value_str) = value.as_str() {
if value_str.len() > MAX_METADATA_VALUE_LENGTH {
anyhow::bail!(
"Metadata value for key '{}' exceeds {} character limit",
key,
MAX_METADATA_VALUE_LENGTH
);
}
if let Some(value_str) = value.as_str()
&& value_str.len() > MAX_METADATA_VALUE_LENGTH
{
anyhow::bail!(
"Metadata value for key '{}' exceeds {} character limit",
key,
MAX_METADATA_VALUE_LENGTH
);
}
}
}
......@@ -438,14 +438,14 @@ pub fn validate_prompt(prompt: &dynamo_async_openai::types::Prompt) -> Result<()
/// Validates logprobs parameter (for completion requests)
pub fn validate_logprobs(logprobs: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = logprobs {
if !(MIN_LOGPROBS..=MAX_LOGPROBS).contains(&value) {
anyhow::bail!(
"Logprobs must be between 0 and {}, got {}",
MAX_LOGPROBS,
value
);
}
if let Some(value) = logprobs
&& !(MIN_LOGPROBS..=MAX_LOGPROBS).contains(&value)
{
anyhow::bail!(
"Logprobs must be between 0 and {}, got {}",
MAX_LOGPROBS,
value
);
}
Ok(())
}
......@@ -461,14 +461,14 @@ pub fn validate_best_of(best_of: Option<u8>, n: Option<u8>) -> Result<(), anyhow
);
}
if let Some(n_value) = n {
if best_of_value < n_value {
anyhow::bail!(
"Best_of must be greater than or equal to n, got best_of={} and n={}",
best_of_value,
n_value
);
}
if let Some(n_value) = n
&& best_of_value < n_value
{
anyhow::bail!(
"Best_of must be greater than or equal to n, got best_of={} and n={}",
best_of_value,
n_value
);
}
}
Ok(())
......@@ -487,10 +487,10 @@ pub fn validate_suffix(suffix: Option<&str>) -> Result<(), anyhow::Error> {
/// Validates max_tokens parameter
pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), anyhow::Error> {
if let Some(tokens) = max_tokens {
if tokens == 0 {
anyhow::bail!("Max tokens must be greater than 0, got {}", tokens);
}
if let Some(tokens) = max_tokens
&& tokens == 0
{
anyhow::bail!("Max tokens must be greater than 0, got {}", tokens);
}
Ok(())
}
......@@ -499,13 +499,13 @@ pub fn validate_max_tokens(max_tokens: Option<u32>) -> Result<(), anyhow::Error>
pub fn validate_max_completion_tokens(
max_completion_tokens: Option<u32>,
) -> Result<(), anyhow::Error> {
if let Some(tokens) = max_completion_tokens {
if tokens == 0 {
anyhow::bail!(
"Max completion tokens must be greater than 0, got {}",
tokens
);
}
if let Some(tokens) = max_completion_tokens
&& tokens == 0
{
anyhow::bail!(
"Max completion tokens must be greater than 0, got {}",
tokens
);
}
Ok(())
}
......
......@@ -8,7 +8,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::fs::{self, File, OpenOptions};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{Mutex, mpsc};
use tokio_util::sync::CancellationToken;
/// Record entry that will be serialized to JSONL
......@@ -70,10 +70,10 @@ where
let first_event_time_clone = first_event_time.clone();
// Ensure the directory exists
if let Some(parent) = output_path.as_ref().parent() {
if !parent.exists() {
fs::create_dir_all(parent).await?;
}
if let Some(parent) = output_path.as_ref().parent()
&& !parent.exists()
{
fs::create_dir_all(parent).await?;
}
// Create the file for writing
......@@ -102,16 +102,16 @@ where
loop {
// Check time limit if set
if let Some(deadline) = max_time_deadline {
if Instant::now() >= deadline {
tracing::info!("Recorder reached max time limit, shutting down");
// Flush and cancel
if let Err(e) = writer.flush().await {
tracing::error!("Failed to flush on time limit shutdown: {}", e);
}
cancel_clone.cancel();
return;
if let Some(deadline) = max_time_deadline
&& Instant::now() >= deadline
{
tracing::info!("Recorder reached max time limit, shutting down");
// Flush and cancel
if let Err(e) = writer.flush().await {
tracing::error!("Failed to flush on time limit shutdown: {}", e);
}
cancel_clone.cancel();
return;
}
tokio::select! {
......@@ -170,8 +170,8 @@ where
line_count += 1;
// Check if we need to rotate to a new file
if let Some(max_lines) = max_lines_per_file {
if line_count >= max_lines {
if let Some(max_lines) = max_lines_per_file
&& line_count >= max_lines {
// Flush the current file
if let Err(e) = writer.flush().await {
tracing::error!("Failed to flush file before rotation: {}", e);
......@@ -200,15 +200,14 @@ where
}
}
}
}
// Update event count
let mut count = event_count_clone.lock().await;
*count += 1;
// Check if we've reached the maximum count
if let Some(max) = max_count {
if *count >= max {
if let Some(max) = max_count
&& *count >= max {
tracing::info!("Recorder reached max event count ({}), shutting down", max);
// Flush buffer before shutting down
if let Err(e) = writer.flush().await {
......@@ -219,7 +218,6 @@ where
cancel_clone.cancel();
return;
}
}
}
}
}
......@@ -307,19 +305,19 @@ where
// Read and send events line by line
while let Some(line) = lines.next_line().await? {
// Check if we've reached the maximum count
if let Some(max) = max_count {
if count >= max {
tracing::info!("Reached maximum event count ({}), stopping", max);
break;
}
if let Some(max) = max_count
&& count >= max
{
tracing::info!("Reached maximum event count ({}), stopping", max);
break;
}
// Check if we've exceeded the time limit
if let Some(end_time) = deadline {
if Instant::now() >= end_time {
tracing::info!("Reached maximum time limit, stopping");
break;
}
if let Some(end_time) = deadline
&& Instant::now() >= end_time
{
tracing::info!("Reached maximum time limit, stopping");
break;
}
line_number += 1;
......@@ -346,12 +344,12 @@ where
let event = record.event;
// Handle timing if needed
if timed && prev_timestamp.is_some() {
let prev = prev_timestamp.unwrap();
if timestamp > prev {
let wait_time = timestamp - prev;
tokio::time::sleep(Duration::from_millis(wait_time)).await;
}
if timed
&& let Some(prev) = prev_timestamp
&& timestamp > prev
{
let wait_time = timestamp - prev;
tokio::time::sleep(Duration::from_millis(wait_time)).await;
}
// Send the event
......@@ -612,7 +610,10 @@ mod tests {
// Should have MAX_LINES_PER_FILE lines in each file (except maybe the last one)
if i < found_files.len() - 1 {
assert_eq!(line_count, MAX_LINES_PER_FILE, "Each file except possibly the last should have exactly MAX_LINES_PER_FILE lines");
assert_eq!(
line_count, MAX_LINES_PER_FILE,
"Each file except possibly the last should have exactly MAX_LINES_PER_FILE lines"
);
}
total_lines += line_count;
......@@ -631,19 +632,19 @@ mod tests {
line_number += 1;
let entry: RecordEntry<TestEvent> = serde_json::from_str(&line).unwrap();
if let Some(prev) = prev_timestamp {
if entry.timestamp < prev {
unsorted_count += 1;
if unsorted_count <= 5 {
// Only log first 5 violations to avoid spam
println!(
"Timestamp order violation in file {} at line {}: {} < {}",
file_path.display(),
line_number,
entry.timestamp,
prev
);
}
if let Some(prev) = prev_timestamp
&& entry.timestamp < prev
{
unsorted_count += 1;
if unsorted_count <= 5 {
// Only log first 5 violations to avoid spam
println!(
"Timestamp order violation in file {} at line {}: {} < {}",
file_path.display(),
line_number,
entry.timestamp,
prev
);
}
}
......
......@@ -16,8 +16,8 @@
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::{
traits::{Decoder, Encoder, Tokenizer},
Encoding, Error, Result, TokenIdType,
traits::{Decoder, Encoder, Tokenizer},
};
pub struct HuggingFaceTokenizer {
......
......@@ -603,7 +603,11 @@ impl TokenBlockSequence {
Some(range) => {
// Since we only added one token, the range can only be empty or have one element.
// If it's not empty, it must be `n..(n+1)`.
assert_eq!(range.len(), 1, "Appending a single token completed more than one block, which should be impossible.");
assert_eq!(
range.len(),
1,
"Appending a single token completed more than one block, which should be impossible."
);
Ok(Some(range.start))
}
}
......@@ -1108,7 +1112,7 @@ mod tests {
let tokens2 = Tokens::from(vec![5, 6, 7, 8]);
let chunk2 = TokenBlockChunk::new(tokens2.clone(), salt);
let block2 = TokenBlock::from_chunk(chunk2, block1.parent_sequence_hash()); // Incorrect parent
// Sequence hash should differ if parent is wrong
// Sequence hash should differ if parent is wrong
assert_ne!(block2.sequence_hash(), SEQ_HASH_5_8);
let chunk2_correct = TokenBlockChunk::new(tokens2.clone(), salt);
......
......@@ -14,13 +14,13 @@
// limitations under the License.
use dynamo_llm::protocols::{
codec::{create_message_stream, Message, SseCodecError},
ContentProvider, DataStream,
codec::{Message, SseCodecError, create_message_stream},
openai::{
chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse},
completions::NvCreateCompletionResponse,
ParsingOptions,
chat_completions::{NvCreateChatCompletionResponse, aggregator::ChatCompletionAggregator},
completions::NvCreateCompletionResponse,
},
ContentProvider, DataStream,
};
use futures::StreamExt;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment