Unverified Commit bce74588 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Rust to 1.89 and edition 2024 (#2659)

parent 268d017e
...@@ -29,8 +29,8 @@ use std::collections::HashMap; ...@@ -29,8 +29,8 @@ use std::collections::HashMap;
use anyhow::Context; use anyhow::Context;
use candle_core::{ use candle_core::{
quantized::gguf_file::{self, Value},
Result, Result,
quantized::gguf_file::{self, Value},
}; };
use tracing::info; use tracing::info;
...@@ -66,7 +66,9 @@ impl Content { ...@@ -66,7 +66,9 @@ impl Content {
accum accum
}); });
if n_splits.len() > 1 { if n_splits.len() > 1 {
candle_core::bail!("GGUF files have differing `split.count` values: {n_splits:?}. Perhaps the GGUF files do not match?"); candle_core::bail!(
"GGUF files have differing `split.count` values: {n_splits:?}. Perhaps the GGUF files do not match?"
);
} }
#[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_truncation)]
if !n_splits.is_empty() && n_readers != n_splits[0] as usize { if !n_splits.is_empty() && n_readers != n_splits[0] as usize {
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
// SOFTWARE. // SOFTWARE.
use akin::akin; use akin::akin;
use anyhow::ensure;
use anyhow::Result; use anyhow::Result;
use anyhow::ensure;
use candle_core::quantized::gguf_file; use candle_core::quantized::gguf_file;
use std::collections::HashMap; use std::collections::HashMap;
use tracing::warn; use tracing::warn;
......
...@@ -31,6 +31,7 @@ use ahash::AHashMap; ...@@ -31,6 +31,7 @@ use ahash::AHashMap;
use anyhow::Result; use anyhow::Result;
use itertools::Itertools; use itertools::Itertools;
use tokenizers::{ use tokenizers::{
AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
decoders::{ decoders::{
self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip, self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
}, },
...@@ -41,7 +42,6 @@ use tokenizers::{ ...@@ -41,7 +42,6 @@ use tokenizers::{
self, self,
template::{self, TemplateProcessing}, template::{self, TemplateProcessing},
}, },
AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
}; };
use tracing::info; use tracing::info;
...@@ -402,7 +402,7 @@ impl TryFrom<Normalizer<'_>> for NormalizerWrapper { ...@@ -402,7 +402,7 @@ impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::Result; use anyhow::Result;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
#[allow(dead_code)] #[allow(dead_code)]
......
...@@ -14,7 +14,7 @@ use std::time::Instant; ...@@ -14,7 +14,7 @@ use std::time::Instant;
use async_trait::async_trait; use async_trait::async_trait;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use dynamo_async_openai::{config::OpenAIConfig, error::OpenAIError, Client}; use dynamo_async_openai::{Client, config::OpenAIConfig, error::OpenAIError};
use futures::Stream; use futures::Stream;
use serde_json::Value; use serde_json::Value;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -22,10 +22,10 @@ use tracing; ...@@ -22,10 +22,10 @@ use tracing;
use uuid::Uuid; use uuid::Uuid;
// Import our existing recording infrastructure // Import our existing recording infrastructure
use crate::protocols::Annotated;
use crate::protocols::openai::chat_completions::{ use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}; };
use crate::protocols::Annotated;
use dynamo_runtime::engine::{ use dynamo_runtime::engine::{
AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream, AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
}; };
...@@ -523,7 +523,7 @@ impl GenericBYOTClient { ...@@ -523,7 +523,7 @@ impl GenericBYOTClient {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tokio::time::{sleep, Duration}; use tokio::time::{Duration, sleep};
#[tokio::test] #[tokio::test]
async fn test_http_request_context_creation() { async fn test_http_request_context_creation() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{service_v2, RouteDoc}; use super::{RouteDoc, service_v2};
use axum::{http::Method, http::StatusCode, response::IntoResponse, routing::get, Json, Router}; use axum::{Json, Router, http::Method, http::StatusCode, response::IntoResponse, routing::get};
use dynamo_runtime::instances::list_all_instances; use dynamo_runtime::instances::list_all_instances;
use serde_json::json; use serde_json::json;
use std::sync::Arc; use std::sync::Arc;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router}; use axum::{Router, extract::State, http::StatusCode, response::IntoResponse, routing::get};
use prometheus::{Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts}; use prometheus::{Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts};
use std::{ use std::{
sync::Arc, sync::Arc,
...@@ -538,7 +538,7 @@ async fn handler_metrics(State(registry): State<Arc<Registry>>) -> impl IntoResp ...@@ -538,7 +538,7 @@ async fn handler_metrics(State(registry): State<Arc<Registry>>) -> impl IntoResp
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"Failed to encode metrics", "Failed to encode metrics",
) )
.into_response() .into_response();
} }
}; };
......
...@@ -8,36 +8,37 @@ use std::{ ...@@ -8,36 +8,37 @@ use std::{
}; };
use axum::{ use axum::{
Json, Router,
extract::State, extract::State,
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{ response::{
sse::{Event, KeepAlive, Sse},
IntoResponse, Response, IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
}, },
routing::{get, post}, routing::{get, post},
Json, Router,
}; };
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::{AsyncEngineContextProvider, Context}, pipeline::{AsyncEngineContextProvider, Context},
protocols::annotated::AnnotationsProvider, protocols::annotated::AnnotationsProvider,
}; };
use futures::{stream, StreamExt}; use futures::{StreamExt, stream};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::{ use super::{
disconnect::{create_connection_monitor, monitor_for_disconnects, ConnectionHandle}, RouteDoc,
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
error::HttpError, error::HttpError,
metrics::{Endpoint, ResponseMetricCollector}, metrics::{Endpoint, ResponseMetricCollector},
service_v2, RouteDoc, service_v2,
}; };
use crate::preprocessor::LLMMetricAnnotation; use crate::preprocessor::LLMMetricAnnotation;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator; use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::{ use crate::protocols::openai::{
ParsingOptions,
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
responses::{NvCreateResponse, NvResponse}, responses::{NvCreateResponse, NvResponse},
ParsingOptions,
}; };
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::Annotated; use crate::types::Annotated;
...@@ -124,18 +125,17 @@ impl ErrorMessage { ...@@ -124,18 +125,17 @@ impl ErrorMessage {
// First check for PipelineError::ServiceOverloaded // First check for PipelineError::ServiceOverloaded
if let Some(pipeline_err) = if let Some(pipeline_err) =
err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>() err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>()
{ && matches!(
if matches!(
pipeline_err, pipeline_err,
dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_) dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_)
) { )
return ( {
StatusCode::SERVICE_UNAVAILABLE, return (
Json(ErrorMessage { StatusCode::SERVICE_UNAVAILABLE,
error: pipeline_err.to_string(), Json(ErrorMessage {
}), error: pipeline_err.to_string(),
); }),
} );
} }
// Then check for HttpError // Then check for HttpError
...@@ -166,17 +166,17 @@ impl From<HttpError> for ErrorMessage { ...@@ -166,17 +166,17 @@ impl From<HttpError> for ErrorMessage {
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present /// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String { fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
// Try to get request id from trace context // Try to get request id from trace context
if let Some(trace_context) = get_distributed_tracing_context() { if let Some(trace_context) = get_distributed_tracing_context()
if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id { && let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id
return x_dynamo_request_id; {
} return x_dynamo_request_id;
} }
// Try to get the request ID from the primary source // Try to get the request ID from the primary source
if let Some(primary) = primary { if let Some(primary) = primary
if let Ok(uuid) = uuid::Uuid::parse_str(primary) { && let Ok(uuid) = uuid::Uuid::parse_str(primary)
return uuid.to_string(); {
} return uuid.to_string();
} }
// Try to get the request ID header as a string slice // Try to get the request ID header as a string slice
...@@ -792,7 +792,9 @@ pub fn validate_response_input_is_text_only( ...@@ -792,7 +792,9 @@ pub fn validate_response_input_is_text_only(
) -> Option<impl IntoResponse> { ) -> Option<impl IntoResponse> {
match &request.inner.input { match &request.inner.input {
dynamo_async_openai::types::responses::Input::Text(_) => None, dynamo_async_openai::types::responses::Input::Text(_) => None,
_ => Some(ErrorMessage::not_implemented_error("Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.")), _ => Some(ErrorMessage::not_implemented_error(
"Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.",
)),
} }
} }
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::env::var; use std::env::var;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use super::metrics;
use super::Metrics; use super::Metrics;
use super::RouteDoc; use super::RouteDoc;
use super::metrics;
use crate::discovery::ModelManager; use crate::discovery::ModelManager;
use crate::endpoint_type::EndpointType; use crate::endpoint_type::EndpointType;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
......
...@@ -9,8 +9,8 @@ use anyhow::Result; ...@@ -9,8 +9,8 @@ use anyhow::Result;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Component, InstanceSource}, component::{Component, InstanceSource},
pipeline::{ pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
ResponseStream, SingleIn, SingleIn, async_trait,
}, },
prelude::*, prelude::*,
protocols::annotated::Annotated, protocols::annotated::Annotated,
...@@ -30,12 +30,12 @@ pub mod scoring; ...@@ -30,12 +30,12 @@ pub mod scoring;
pub mod sequence; pub mod sequence;
use crate::{ use crate::{
discovery::{ModelEntry, MODEL_ROOT_PATH}, discovery::{MODEL_ROOT_PATH, ModelEntry},
kv_router::{ kv_router::{
approx::ApproxKvIndexer, approx::ApproxKvIndexer,
indexer::{ indexer::{
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface, KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
KvRouterError, OverlapScores, RouterEvent, compute_block_hash_for_seq, compute_seq_hash_for_block,
}, },
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
......
...@@ -25,15 +25,15 @@ use tokio_util::sync::CancellationToken; ...@@ -25,15 +25,15 @@ use tokio_util::sync::CancellationToken;
use crate::tokens::{SequenceHash, TokenBlockSequence}; use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::RouterEvent;
use crate::kv_router::indexer::{ use crate::kv_router::indexer::{
compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, WorkerId,
RadixTree, WorkerId, compute_block_hash_for_seq,
}; };
use crate::kv_router::protocols::{ use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash, KvCacheStoredBlockData, LocalBlockHash,
}; };
use crate::kv_router::RouterEvent;
#[derive(Debug)] #[derive(Debug)]
struct MatchRequest { struct MatchRequest {
......
...@@ -1382,10 +1382,11 @@ mod tests { ...@@ -1382,10 +1382,11 @@ mod tests {
let worker_0 = 0; let worker_0 = 0;
let worker_1 = 1; let worker_1 = 1;
assert!(trie assert!(
.find_matches(vec![LocalBlockHash(0)], false) trie.find_matches(vec![LocalBlockHash(0)], false)
.scores .scores
.is_empty()); .is_empty()
);
trie.apply_event(create_store_event(worker_0, 0, vec![0], None)); trie.apply_event(create_store_event(worker_0, 0, vec![0], None));
trie.apply_event(create_store_event(worker_1, 0, vec![0], None)); trie.apply_event(create_store_event(worker_1, 0, vec![0], None));
...@@ -1406,10 +1407,11 @@ mod tests { ...@@ -1406,10 +1407,11 @@ mod tests {
let worker_0 = 0; let worker_0 = 0;
let worker_1 = 1; let worker_1 = 1;
assert!(trie assert!(
.find_matches(vec![LocalBlockHash(0)], false) trie.find_matches(vec![LocalBlockHash(0)], false)
.scores .scores
.is_empty()); .is_empty()
);
// Test clearing an empty worker // Test clearing an empty worker
trie.clear_all_blocks(worker_0); trie.clear_all_blocks(worker_0);
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
use std::sync::Once; use std::sync::Once;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT; use crate::kv_router::KV_METRICS_ENDPOINT;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::scoring::Endpoint;
use crate::kv_router::ProcessedEndpoints; use crate::kv_router::ProcessedEndpoints;
use crate::kv_router::scoring::Endpoint;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; use dynamo_runtime::{Result, service::EndpointInfo, utils::Duration};
use tokio::sync::watch; use tokio::sync::watch;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
......
...@@ -14,21 +14,21 @@ ...@@ -14,21 +14,21 @@
// limitations under the License. // limitations under the License.
use crate::kv_router::{ use crate::kv_router::{
indexer::{compute_block_hash_for_seq, RouterEvent}, KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq},
protocols::*, protocols::*,
scoring::LoadEvent, scoring::LoadEvent,
KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider}; use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::{ use dynamo_runtime::{
Error, Result,
component::{Component, Namespace}, component::{Component, Namespace},
pipeline::{ pipeline::{
network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
SingleIn, network::Ingress,
}, },
protocols::annotated::Annotated, protocols::annotated::Annotated,
Error, Result,
}; };
use futures::stream; use futures::stream;
use std::sync::Arc; use std::sync::Arc;
......
...@@ -11,12 +11,12 @@ use std::sync::Arc; ...@@ -11,12 +11,12 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::watch; use tokio::sync::watch;
use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
use super::WorkerSelector;
use super::indexer::OverlapScores; use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult; use super::protocols::WorkerSelectionResult;
use super::sequence::ActiveSequencesMultiWorker; use super::sequence::ActiveSequencesMultiWorker;
use super::KvRouterConfig;
use super::WorkerSelector;
use super::KV_HIT_RATE_SUBJECT;
use crate::tokens::SequenceHash; use crate::tokens::SequenceHash;
...@@ -293,7 +293,7 @@ fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 { ...@@ -293,7 +293,7 @@ fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
// Collect all keys with the minimum logit value (to handle ties) // Collect all keys with the minimum logit value (to handle ties)
let min_keys: Vec<_> = logits let min_keys: Vec<_> = logits
.iter() .iter()
.filter(|(_, &v)| v == min_logit) .filter(|&(_, &v)| v == min_logit)
.map(|(k, _)| *k) .map(|(k, _)| *k)
.collect(); .collect();
......
...@@ -29,8 +29,8 @@ use anyhow::Result; ...@@ -29,8 +29,8 @@ use anyhow::Result;
use dashmap::DashMap; use dashmap::DashMap;
use derive_getters::Getters; use derive_getters::Getters;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt; use futures::StreamExt;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
...@@ -428,21 +428,21 @@ impl ActiveSequencesMultiWorker { ...@@ -428,21 +428,21 @@ impl ActiveSequencesMultiWorker {
} }
} }
ActiveSequenceEventData::Free => { ActiveSequenceEventData::Free => {
if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id) { if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id)
if let Some(sender) = senders.get(&worker_id) { && let Some(sender) = senders.get(&worker_id)
let _ = sender.send(UpdateSequences::Free { {
request_id: event.request_id.clone(), let _ = sender.send(UpdateSequences::Free {
}); request_id: event.request_id.clone(),
} });
} }
} }
ActiveSequenceEventData::MarkPrefillCompleted => { ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker_id) = request_to_worker.get(&event.request_id) { if let Some(worker_id) = request_to_worker.get(&event.request_id)
if let Some(sender) = senders.get(&*worker_id) { && let Some(sender) = senders.get(&*worker_id)
let _ = sender.send(UpdateSequences::MarkPrefillCompleted { {
request_id: event.request_id.clone(), let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
}); request_id: event.request_id.clone(),
} });
} }
} }
} }
......
...@@ -238,9 +238,10 @@ mod file_json_field_tests { ...@@ -238,9 +238,10 @@ mod file_json_field_tests {
let result: anyhow::Result<String> = file_json_field(&file_path, "non_existent_field"); let result: anyhow::Result<String> = file_json_field(&file_path, "non_existent_field");
assert!(result.is_err()); assert!(result.is_err());
let err = result.unwrap_err(); let err = result.unwrap_err();
assert!(err assert!(
.to_string() err.to_string()
.contains("Field 'non_existent_field' not found")); .contains("Field 'non_existent_field' not found")
);
} }
#[test] #[test]
...@@ -255,9 +256,10 @@ mod file_json_field_tests { ...@@ -255,9 +256,10 @@ mod file_json_field_tests {
let result: anyhow::Result<u32> = file_json_field(&file_path, "count"); let result: anyhow::Result<u32> = file_json_field(&file_path, "count");
assert!(result.is_err()); assert!(result.is_err());
let err = result.unwrap_err(); let err = result.unwrap_err();
assert!(err assert!(
.to_string() err.to_string()
.contains("Failed to deserialize field 'count'")); .contains("Failed to deserialize field 'count'")
);
} }
#[test] #[test]
......
...@@ -263,17 +263,15 @@ impl LocalModelBuilder { ...@@ -263,17 +263,15 @@ impl LocalModelBuilder {
} }
// Override runtime configs with mocker engine args // Override runtime configs with mocker engine args
if self.is_mocker { if self.is_mocker
if let Some(path) = &self.extra_engine_args { && let Some(path) = &self.extra_engine_args
let mocker_engine_args = MockEngineArgs::from_json_file(path) {
.expect("Failed to load mocker engine args for runtime config overriding."); let mocker_engine_args = MockEngineArgs::from_json_file(path)
self.runtime_config.total_kv_blocks = .expect("Failed to load mocker engine args for runtime config overriding.");
Some(mocker_engine_args.num_gpu_blocks as u64); self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
mocker_engine_args.max_num_seqs.map(|v| v as u64); self.runtime_config.max_num_batched_tokens =
self.runtime_config.max_num_batched_tokens = mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
} }
card.migration_limit = self.migration_limit; card.migration_limit = self.migration_limit;
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelRuntimeConfig { pub struct ModelRuntimeConfig {
......
...@@ -17,8 +17,8 @@ use crate::{ ...@@ -17,8 +17,8 @@ use crate::{
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::{ pipeline::{
async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
ServerStreamingEngine, SingleIn, SingleIn, async_trait,
}, },
protocols::{annotated::Annotated, maybe_error::MaybeError}, protocols::{annotated::Annotated, maybe_error::MaybeError},
}; };
...@@ -126,13 +126,12 @@ impl RetryManager { ...@@ -126,13 +126,12 @@ impl RetryManager {
// TODO: Is there anything needed to pass between context? // TODO: Is there anything needed to pass between context?
let request = SingleIn::new(self.request.clone()); let request = SingleIn::new(self.request.clone());
response_stream = Some(self.next_generate.generate(request).await); response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() { if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() { && let Some(req_err) = err.downcast_ref::<NatsRequestError>()
if matches!(req_err.kind(), NatsNoResponders) { && matches!(req_err.kind(), NatsNoResponders)
tracing::warn!("Creating new stream... retrying..."); {
continue; tracing::warn!("Creating new stream... retrying...");
} continue;
}
} }
break; break;
} }
...@@ -170,8 +169,8 @@ impl RetryManager { ...@@ -170,8 +169,8 @@ impl RetryManager {
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::pipeline::context::Controller;
use dynamo_runtime::pipeline::AsyncEngine; use dynamo_runtime::pipeline::AsyncEngine;
use dynamo_runtime::pipeline::context::Controller;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::mpsc; use tokio::sync::mpsc;
...@@ -624,9 +623,11 @@ mod tests { ...@@ -624,9 +623,11 @@ mod tests {
let error_response = &responses[3]; let error_response = &responses[3];
assert!(error_response.err().is_some()); assert!(error_response.err().is_some());
if let Some(error) = error_response.err() { if let Some(error) = error_response.err() {
assert!(error assert!(
.to_string() error
.contains("Stream ended before generation completed")); .to_string()
.contains("Stream ended before generation completed")
);
} }
} }
...@@ -672,9 +673,11 @@ mod tests { ...@@ -672,9 +673,11 @@ mod tests {
let error_response = &responses[3]; let error_response = &responses[3];
assert!(error_response.err().is_some()); assert!(error_response.err().is_some());
if let Some(error) = error_response.err() { if let Some(error) = error_response.err() {
assert!(error assert!(
.to_string() error
.contains("Stream ended before generation completed")); .to_string()
.contains("Stream ended before generation completed")
);
} }
} }
} }
...@@ -22,18 +22,18 @@ use crate::kv_router::publisher::WorkerMetricsPublisher; ...@@ -22,18 +22,18 @@ use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest; use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal}; use crate::mocker::protocols::{MockEngineArgs, OutputSignal};
use crate::mocker::scheduler::Scheduler; use crate::mocker::scheduler::Scheduler;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::annotated::Annotated; use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use dynamo_runtime::{ use dynamo_runtime::{
Result,
component::Component, component::Component,
engine::AsyncEngineContextProvider, engine::AsyncEngineContextProvider,
pipeline::{async_trait, AsyncEngine, Error, ManyOut, ResponseStream, SingleIn}, pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
Result,
}; };
use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData}; use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
...@@ -43,7 +43,7 @@ use rand::Rng; ...@@ -43,7 +43,7 @@ use rand::Rng;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{mpsc, Mutex, OnceCell}; use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid; use uuid::Uuid;
...@@ -523,14 +523,14 @@ pub async fn make_mocker_engine( ...@@ -523,14 +523,14 @@ pub async fn make_mocker_engine(
#[cfg(test)] #[cfg(test)]
mod integration_tests { mod integration_tests {
use super::*; use super::*;
use crate::kv_router::indexer::RouterEvent;
use crate::kv_router::KV_EVENT_SUBJECT; use crate::kv_router::KV_EVENT_SUBJECT;
use crate::kv_router::indexer::RouterEvent;
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::{ use dynamo_runtime::{
DistributedRuntime, Worker,
pipeline::Context, pipeline::Context,
pipeline::{network::Ingress, PushRouter}, pipeline::{PushRouter, network::Ingress},
traits::events::EventSubscriber, traits::events::EventSubscriber,
DistributedRuntime, Worker,
}; };
use futures::StreamExt; use futures::StreamExt;
use tokio::time::timeout; use tokio::time::timeout;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment