Unverified Commit bce74588 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Rust to 1.89 and edition 2024 (#2659)

parent 268d017e
......@@ -29,8 +29,8 @@ use std::collections::HashMap;
use anyhow::Context;
use candle_core::{
quantized::gguf_file::{self, Value},
Result,
quantized::gguf_file::{self, Value},
};
use tracing::info;
......@@ -66,7 +66,9 @@ impl Content {
accum
});
if n_splits.len() > 1 {
candle_core::bail!("GGUF files have differing `split.count` values: {n_splits:?}. Perhaps the GGUF files do not match?");
candle_core::bail!(
"GGUF files have differing `split.count` values: {n_splits:?}. Perhaps the GGUF files do not match?"
);
}
#[allow(clippy::cast_possible_truncation)]
if !n_splits.is_empty() && n_readers != n_splits[0] as usize {
......
......@@ -26,8 +26,8 @@
// SOFTWARE.
use akin::akin;
use anyhow::ensure;
use anyhow::Result;
use anyhow::ensure;
use candle_core::quantized::gguf_file;
use std::collections::HashMap;
use tracing::warn;
......
......@@ -31,6 +31,7 @@ use ahash::AHashMap;
use anyhow::Result;
use itertools::Itertools;
use tokenizers::{
AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
decoders::{
self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
},
......@@ -41,7 +42,6 @@ use tokenizers::{
self,
template::{self, TemplateProcessing},
},
AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
};
use tracing::info;
......@@ -402,7 +402,7 @@ impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
#[cfg(test)]
mod tests {
use anyhow::Result;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use tokenizers::Tokenizer;
#[allow(dead_code)]
......
......@@ -14,7 +14,7 @@ use std::time::Instant;
use async_trait::async_trait;
use derive_getters::Dissolve;
use dynamo_async_openai::{config::OpenAIConfig, error::OpenAIError, Client};
use dynamo_async_openai::{Client, config::OpenAIConfig, error::OpenAIError};
use futures::Stream;
use serde_json::Value;
use tokio_util::sync::CancellationToken;
......@@ -22,10 +22,10 @@ use tracing;
use uuid::Uuid;
// Import our existing recording infrastructure
use crate::protocols::Annotated;
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::protocols::Annotated;
use dynamo_runtime::engine::{
AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
};
......@@ -523,7 +523,7 @@ impl GenericBYOTClient {
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{sleep, Duration};
use tokio::time::{Duration, sleep};
#[tokio::test]
async fn test_http_request_context_creation() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::{service_v2, RouteDoc};
use axum::{http::Method, http::StatusCode, response::IntoResponse, routing::get, Json, Router};
use super::{RouteDoc, service_v2};
use axum::{Json, Router, http::Method, http::StatusCode, response::IntoResponse, routing::get};
use dynamo_runtime::instances::list_all_instances;
use serde_json::json;
use std::sync::Arc;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router};
use axum::{Router, extract::State, http::StatusCode, response::IntoResponse, routing::get};
use prometheus::{Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts};
use std::{
sync::Arc,
......@@ -538,7 +538,7 @@ async fn handler_metrics(State(registry): State<Arc<Registry>>) -> impl IntoResp
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to encode metrics",
)
.into_response()
.into_response();
}
};
......
......@@ -8,36 +8,37 @@ use std::{
};
use axum::{
Json, Router,
extract::State,
http::{HeaderMap, StatusCode},
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
},
routing::{get, post},
Json, Router,
};
use dynamo_runtime::{
pipeline::{AsyncEngineContextProvider, Context},
protocols::annotated::AnnotationsProvider,
};
use futures::{stream, StreamExt};
use futures::{StreamExt, stream};
use serde::{Deserialize, Serialize};
use super::{
disconnect::{create_connection_monitor, monitor_for_disconnects, ConnectionHandle},
RouteDoc,
disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
error::HttpError,
metrics::{Endpoint, ResponseMetricCollector},
service_v2, RouteDoc,
service_v2,
};
use crate::preprocessor::LLMMetricAnnotation;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::{
ParsingOptions,
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
responses::{NvCreateResponse, NvResponse},
ParsingOptions,
};
use crate::request_template::RequestTemplate;
use crate::types::Annotated;
......@@ -124,18 +125,17 @@ impl ErrorMessage {
// First check for PipelineError::ServiceOverloaded
if let Some(pipeline_err) =
err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>()
{
if matches!(
&& matches!(
pipeline_err,
dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_)
) {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorMessage {
error: pipeline_err.to_string(),
}),
);
}
)
{
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorMessage {
error: pipeline_err.to_string(),
}),
);
}
// Then check for HttpError
......@@ -166,17 +166,17 @@ impl From<HttpError> for ErrorMessage {
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
// Try to get request id from trace context
if let Some(trace_context) = get_distributed_tracing_context() {
if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
return x_dynamo_request_id;
}
if let Some(trace_context) = get_distributed_tracing_context()
&& let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id
{
return x_dynamo_request_id;
}
// Try to get the request ID from the primary source
if let Some(primary) = primary {
if let Ok(uuid) = uuid::Uuid::parse_str(primary) {
return uuid.to_string();
}
if let Some(primary) = primary
&& let Ok(uuid) = uuid::Uuid::parse_str(primary)
{
return uuid.to_string();
}
// Try to get the request ID header as a string slice
......@@ -792,7 +792,9 @@ pub fn validate_response_input_is_text_only(
) -> Option<impl IntoResponse> {
match &request.inner.input {
dynamo_async_openai::types::responses::Input::Text(_) => None,
_ => Some(ErrorMessage::not_implemented_error("Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.")),
_ => Some(ErrorMessage::not_implemented_error(
"Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.",
)),
}
}
......
......@@ -4,14 +4,14 @@
use std::collections::HashMap;
use std::env::var;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use super::metrics;
use super::Metrics;
use super::RouteDoc;
use super::metrics;
use crate::discovery::ModelManager;
use crate::endpoint_type::EndpointType;
use crate::request_template::RequestTemplate;
......
......@@ -9,8 +9,8 @@ use anyhow::Result;
use dynamo_runtime::{
component::{Component, InstanceSource},
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
ResponseStream, SingleIn,
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait,
},
prelude::*,
protocols::annotated::Annotated,
......@@ -30,12 +30,12 @@ pub mod scoring;
pub mod sequence;
use crate::{
discovery::{ModelEntry, MODEL_ROOT_PATH},
discovery::{MODEL_ROOT_PATH, ModelEntry},
kv_router::{
approx::ApproxKvIndexer,
indexer::{
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
KvRouterError, OverlapScores, RouterEvent,
KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent,
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
......
......@@ -25,15 +25,15 @@ use tokio_util::sync::CancellationToken;
use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::RouterEvent;
use crate::kv_router::indexer::{
compute_block_hash_for_seq, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores,
RadixTree, WorkerId,
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, WorkerId,
compute_block_hash_for_seq,
};
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use crate::kv_router::RouterEvent;
#[derive(Debug)]
struct MatchRequest {
......
......@@ -1382,10 +1382,11 @@ mod tests {
let worker_0 = 0;
let worker_1 = 1;
assert!(trie
.find_matches(vec![LocalBlockHash(0)], false)
.scores
.is_empty());
assert!(
trie.find_matches(vec![LocalBlockHash(0)], false)
.scores
.is_empty()
);
trie.apply_event(create_store_event(worker_0, 0, vec![0], None));
trie.apply_event(create_store_event(worker_1, 0, vec![0], None));
......@@ -1406,10 +1407,11 @@ mod tests {
let worker_0 = 0;
let worker_1 = 1;
assert!(trie
.find_matches(vec![LocalBlockHash(0)], false)
.scores
.is_empty());
assert!(
trie.find_matches(vec![LocalBlockHash(0)], false)
.scores
.is_empty()
);
// Test clearing an empty worker
trie.clear_all_blocks(worker_0);
......
......@@ -15,13 +15,13 @@
use std::sync::Once;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::scoring::Endpoint;
use crate::kv_router::ProcessedEndpoints;
use crate::kv_router::scoring::Endpoint;
use dynamo_runtime::component::Component;
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
use dynamo_runtime::{Result, service::EndpointInfo, utils::Duration};
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
......
......@@ -14,21 +14,21 @@
// limitations under the License.
use crate::kv_router::{
indexer::{compute_block_hash_for_seq, RouterEvent},
KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq},
protocols::*,
scoring::LoadEvent,
KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT, KV_METRICS_SUBJECT,
};
use async_trait::async_trait;
use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider};
use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::{
Error, Result,
component::{Component, Namespace},
pipeline::{
network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream,
SingleIn,
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
network::Ingress,
},
protocols::annotated::Annotated,
Error, Result,
};
use futures::stream;
use std::sync::Arc;
......
......@@ -11,12 +11,12 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
use super::WorkerSelector;
use super::indexer::OverlapScores;
use super::protocols::WorkerSelectionResult;
use super::sequence::ActiveSequencesMultiWorker;
use super::KvRouterConfig;
use super::WorkerSelector;
use super::KV_HIT_RATE_SUBJECT;
use crate::tokens::SequenceHash;
......@@ -293,7 +293,7 @@ fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
// Collect all keys with the minimum logit value (to handle ties)
let min_keys: Vec<_> = logits
.iter()
.filter(|(_, &v)| v == min_logit)
.filter(|&(_, &v)| v == min_logit)
.map(|(k, _)| *k)
.collect();
......
......@@ -29,8 +29,8 @@ use anyhow::Result;
use dashmap::DashMap;
use derive_getters::Getters;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
......@@ -428,21 +428,21 @@ impl ActiveSequencesMultiWorker {
}
}
ActiveSequenceEventData::Free => {
if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id) {
if let Some(sender) = senders.get(&worker_id) {
let _ = sender.send(UpdateSequences::Free {
request_id: event.request_id.clone(),
});
}
if let Some((_, worker_id)) = request_to_worker.remove(&event.request_id)
&& let Some(sender) = senders.get(&worker_id)
{
let _ = sender.send(UpdateSequences::Free {
request_id: event.request_id.clone(),
});
}
}
ActiveSequenceEventData::MarkPrefillCompleted => {
if let Some(worker_id) = request_to_worker.get(&event.request_id) {
if let Some(sender) = senders.get(&*worker_id) {
let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
request_id: event.request_id.clone(),
});
}
if let Some(worker_id) = request_to_worker.get(&event.request_id)
&& let Some(sender) = senders.get(&*worker_id)
{
let _ = sender.send(UpdateSequences::MarkPrefillCompleted {
request_id: event.request_id.clone(),
});
}
}
}
......
......@@ -238,9 +238,10 @@ mod file_json_field_tests {
let result: anyhow::Result<String> = file_json_field(&file_path, "non_existent_field");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err
.to_string()
.contains("Field 'non_existent_field' not found"));
assert!(
err.to_string()
.contains("Field 'non_existent_field' not found")
);
}
#[test]
......@@ -255,9 +256,10 @@ mod file_json_field_tests {
let result: anyhow::Result<u32> = file_json_field(&file_path, "count");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err
.to_string()
.contains("Failed to deserialize field 'count'"));
assert!(
err.to_string()
.contains("Failed to deserialize field 'count'")
);
}
#[test]
......
......@@ -263,17 +263,15 @@ impl LocalModelBuilder {
}
// Override runtime configs with mocker engine args
if self.is_mocker {
if let Some(path) = &self.extra_engine_args {
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.runtime_config.total_kv_blocks =
Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs =
mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
if self.is_mocker
&& let Some(path) = &self.extra_engine_args
{
let mocker_engine_args = MockEngineArgs::from_json_file(path)
.expect("Failed to load mocker engine args for runtime config overriding.");
self.runtime_config.total_kv_blocks = Some(mocker_engine_args.num_gpu_blocks as u64);
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
}
card.migration_limit = self.migration_limit;
......
......@@ -3,7 +3,7 @@
use std::collections::HashMap;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelRuntimeConfig {
......
......@@ -17,8 +17,8 @@ use crate::{
use dynamo_runtime::{
pipeline::{
async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn,
AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
SingleIn, async_trait,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
};
......@@ -126,13 +126,12 @@ impl RetryManager {
// TODO: Is there anything needed to pass between context?
let request = SingleIn::new(self.request.clone());
response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
if matches!(req_err.kind(), NatsNoResponders) {
tracing::warn!("Creating new stream... retrying...");
continue;
}
}
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
&& matches!(req_err.kind(), NatsNoResponders)
{
tracing::warn!("Creating new stream... retrying...");
continue;
}
break;
}
......@@ -170,8 +169,8 @@ impl RetryManager {
mod tests {
use super::*;
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::pipeline::context::Controller;
use dynamo_runtime::pipeline::AsyncEngine;
use dynamo_runtime::pipeline::context::Controller;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::mpsc;
......@@ -624,9 +623,11 @@ mod tests {
let error_response = &responses[3];
assert!(error_response.err().is_some());
if let Some(error) = error_response.err() {
assert!(error
.to_string()
.contains("Stream ended before generation completed"));
assert!(
error
.to_string()
.contains("Stream ended before generation completed")
);
}
}
......@@ -672,9 +673,11 @@ mod tests {
let error_response = &responses[3];
assert!(error_response.err().is_some());
if let Some(error) = error_response.err() {
assert!(error
.to_string()
.contains("Stream ended before generation completed"));
assert!(
error
.to_string()
.contains("Stream ended before generation completed")
);
}
}
}
......@@ -22,18 +22,18 @@ use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::annotated::Annotated;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::{
Result,
component::Component,
engine::AsyncEngineContextProvider,
pipeline::{async_trait, AsyncEngine, Error, ManyOut, ResponseStream, SingleIn},
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider,
Result,
};
use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
......@@ -43,7 +43,7 @@ use rand::Rng;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex, OnceCell};
use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
......@@ -523,14 +523,14 @@ pub async fn make_mocker_engine(
#[cfg(test)]
mod integration_tests {
use super::*;
use crate::kv_router::indexer::RouterEvent;
use crate::kv_router::KV_EVENT_SUBJECT;
use crate::kv_router::indexer::RouterEvent;
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::{
DistributedRuntime, Worker,
pipeline::Context,
pipeline::{network::Ingress, PushRouter},
pipeline::{PushRouter, network::Ingress},
traits::events::EventSubscriber,
DistributedRuntime, Worker,
};
use futures::StreamExt;
use tokio::time::timeout;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment