Commit bce74588 (unverified), authored by Graham King and committed by GitHub
Browse files

chore: Rust to 1.89 and edition 2024 (#2659)

parent 268d017e
......@@ -46,10 +46,10 @@ pub mod llm_kvbm {
},
};
use dynamo_llm::tokens::{BlockHash, SequenceHash};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::component::Namespace;
use dynamo_runtime::prelude::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use dynamo_runtime::DistributedRuntime;
use kvbm::events::EventManager;
use tokio::sync::mpsc;
pub use tokio_util::sync::CancellationToken;
......@@ -383,18 +383,18 @@ mod tests {
use dynamo_llm::tokens::{TokenBlockSequence, Tokens};
use dynamo_runtime::{
traits::events::{EventPublisher, EventSubscriber},
DistributedRuntime, Runtime,
traits::events::{EventPublisher, EventSubscriber},
};
use kvbm::{
block::registry::BlockRegistry,
block::state::CompleteState,
KvBlockManagerConfig, KvManagerLayoutConfig, KvManagerModelConfig, NixlOptions,
ReferenceBlockManager,
block::BlockState,
block::GlobalRegistry,
block::registry::BlockRegistry,
block::state::CompleteState,
events::EventManager,
storage::{DeviceAllocator, DiskAllocator, PinnedAllocator},
KvBlockManagerConfig, KvManagerLayoutConfig, KvManagerModelConfig, NixlOptions,
ReferenceBlockManager,
};
use dynamo_llm::kv_router::{
......
......@@ -21,30 +21,30 @@ use dynamo_llm::http::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, PureOpenAIClient,
},
service::{
Metrics,
error::HttpError,
metrics::{Endpoint, RequestType, Status, FRONTEND_METRIC_PREFIX},
metrics::{Endpoint, FRONTEND_METRIC_PREFIX, RequestType, Status},
service_v2::HttpService,
Metrics,
},
};
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
Annotated,
};
use dynamo_runtime::{
CancellationToken,
engine::AsyncEngineContext,
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn,
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait,
},
CancellationToken,
};
use futures::StreamExt;
use prometheus::{proto::MetricType, Registry};
use prometheus::{Registry, proto::MetricType};
use reqwest::StatusCode;
use std::{io::Cursor, sync::Arc};
use tokio::time::timeout;
......@@ -1232,23 +1232,23 @@ async fn test_request_id_annotation() {
let mut annotated_stream = std::pin::pin!(annotated_stream);
while let Some(annotated_response) = annotated_stream.next().await {
// Check if this is a request_id annotation
if let Some(event) = &annotated_response.event {
if event == "request_id" {
found_request_id_annotation = true;
// Extract the request ID from the annotation
if let Some(comments) = &annotated_response.comment {
if let Some(comment) = comments.first() {
// The comment contains a JSON-encoded string, so we need to parse it
if let Ok(parsed_value) = serde_json::from_str::<String>(comment) {
received_request_id = Some(parsed_value);
} else {
// Fallback: remove quotes manually if JSON parsing fails
received_request_id = Some(comment.trim_matches('"').to_string());
}
}
if let Some(event) = &annotated_response.event
&& event == "request_id"
{
found_request_id_annotation = true;
// Extract the request ID from the annotation
if let Some(comments) = &annotated_response.comment
&& let Some(comment) = comments.first()
{
// The comment contains a JSON-encoded string, so we need to parse it
if let Ok(parsed_value) = serde_json::from_str::<String>(comment) {
received_request_id = Some(parsed_value);
} else {
// Fallback: remove quotes manually if JSON parsing fails
received_request_id = Some(comment.trim_matches('"').to_string());
}
break;
}
break;
}
}
......
......@@ -15,7 +15,7 @@ use ports::get_random_port;
#[serial]
async fn metrics_prefix_default_then_env_override() {
// Case 1: default prefix
env::remove_var(metrics::METRICS_PREFIX_ENV);
unsafe { env::remove_var(metrics::METRICS_PREFIX_ENV) };
let p1 = get_random_port().await;
let svc1 = HttpService::builder().port(p1).build().unwrap();
let token1 = CancellationToken::new();
......@@ -42,7 +42,7 @@ async fn metrics_prefix_default_then_env_override() {
let _ = h1.await; // ensure port is released
// Case 2: env override to prefix
env::set_var(metrics::METRICS_PREFIX_ENV, "custom_prefix");
unsafe { env::set_var(metrics::METRICS_PREFIX_ENV, "custom_prefix") };
let p2 = get_random_port().await;
let svc2 = HttpService::builder().port(p2).build().unwrap();
let token2 = CancellationToken::new();
......@@ -69,7 +69,7 @@ async fn metrics_prefix_default_then_env_override() {
let _ = h2.await;
// Case 3: invalid env prefix is sanitized
env::set_var(metrics::METRICS_PREFIX_ENV, "nv-llm/http service");
unsafe { env::set_var(metrics::METRICS_PREFIX_ENV, "nv-llm/http service") };
let p3 = get_random_port().await;
let svc3 = HttpService::builder().port(p3).build().unwrap();
let token3 = CancellationToken::new();
......@@ -94,7 +94,7 @@ async fn metrics_prefix_default_then_env_override() {
let _ = h3.await;
// Cleanup env to avoid leaking state
env::remove_var(metrics::METRICS_PREFIX_ENV);
unsafe { env::remove_var(metrics::METRICS_PREFIX_ENV) };
}
// Poll /metrics until ready or timeout
......
......@@ -235,8 +235,8 @@ fn create_multi_choice_stream() -> Arc<RecordedStream<NvCreateChatCompletionStre
// record_stream_with_context(Box::pin(filtered_stream), ctx, RecordingMode::Sink);
// }
fn create_stream_with_multiple_close_tokens(
) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
fn create_stream_with_multiple_close_tokens()
-> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
let start_time = Instant::now();
let responses = vec![TimestampedResponse::new(
create_response_with_linear_probs(
......
......@@ -8,7 +8,7 @@ use dynamo_llm::preprocessor::prompt::PromptFormatter;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use serde::{Deserialize, Serialize};
use hf_hub::{api::tokio::ApiBuilder, Cache, Repo, RepoType};
use hf_hub::{Cache, Repo, RepoType, api::tokio::ApiBuilder};
use std::path::PathBuf;
......
......@@ -132,7 +132,7 @@ fn test_chat_completions_common_overrides_nvext() {
Some(true)
);
assert_eq!(request.get_guided_regex(), Some(".*".to_string())); // common value takes precedence
// Verify precedence through stop conditions extraction
// Verify precedence through stop conditions extraction
let stop_conditions = request.extract_stop_conditions().unwrap();
assert_eq!(stop_conditions.ignore_eos, Some(false)); // common value takes precedence
assert_eq!(stop_conditions.min_tokens, Some(50));
......
......@@ -8,10 +8,10 @@ pub mod tools;
// Re-export main types and functions for convenience
pub use json_parser::{
try_tool_call_parse_json, CalledFunctionArguments, CalledFunctionParameters,
CalledFunctionArguments, CalledFunctionParameters, try_tool_call_parse_json,
};
pub use parsers::{
detect_and_parse_tool_call, JsonParserConfig, ToolCallConfig, ToolCallParserType,
JsonParserConfig, ToolCallConfig, ToolCallParserType, detect_and_parse_tool_call,
};
pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};
......@@ -839,8 +839,8 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag_multiple_with_new_lines(
) {
fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag_multiple_with_new_lines()
{
let input = r#"
{"name": "get_weather", "arguments":
{"location": "San Francisco, CA",
......
......@@ -5,7 +5,7 @@ pub use super::response::*;
// Import json_parser from postprocessor module
pub use super::json_parser::*;
pub use super::parsers::{detect_and_parse_tool_call, ToolCallConfig};
pub use super::parsers::{ToolCallConfig, detect_and_parse_tool_call};
/// Try parsing a string as a structured tool call, for aggregation usage.
///
......
......@@ -268,6 +268,15 @@ version = "1.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "1.3.2"
......@@ -658,6 +667,7 @@ dependencies = [
"async-trait",
"async_zmq",
"axum",
"bincode",
"blake3",
"bytes",
"chrono",
......
[toolchain]
channel = "1.87.0"
channel = "1.89.0"
......@@ -32,21 +32,20 @@
use crate::{
config::HealthStatus,
discovery::Lease,
metrics::{prometheus_names, MetricsRegistry},
metrics::{MetricsRegistry, prometheus_names},
service::ServiceSet,
transports::etcd::EtcdPath,
};
use super::{
error,
DistributedRuntime, Result, Runtime, error,
traits::*,
transports::etcd::{COMPONENT_KEYWORD, ENDPOINT_KEYWORD},
transports::nats::Slug,
utils::Duration,
DistributedRuntime, Result, Runtime,
};
use crate::pipeline::network::{ingress::push_endpoint::PushEndpoint, PushWorkHandler};
use crate::pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint};
use crate::protocols::EndpointId;
use crate::service::ComponentNatsServerPrometheusMetrics;
use async_nats::{
......@@ -288,11 +287,13 @@ impl Component {
let component_clone = self.clone();
let mut hierarchies = self.parent_hierarchy();
hierarchies.push(self.hierarchy());
debug_assert!(hierarchies
.last()
.map(|x| x.as_str())
.unwrap_or_default()
.eq_ignore_ascii_case(&self.service_name())); // it happens that in component, hierarchy and service name are the same
debug_assert!(
hierarchies
.last()
.map(|x| x.as_str())
.unwrap_or_default()
.eq_ignore_ascii_case(&self.service_name())
); // it happens that in component, hierarchy and service name are the same
// Start a background task that scrapes stats every 5 seconds
let m = component_metrics.clone();
......
......@@ -148,19 +148,18 @@ impl EndpointConfigBuilder {
let info = serde_json::to_vec_pretty(&info)?;
if let Some(etcd_client) = &endpoint.component.drt.etcd_client {
if let Err(e) = etcd_client
if let Some(etcd_client) = &endpoint.component.drt.etcd_client
&& let Err(e) = etcd_client
.kv_create(
&endpoint.etcd_path_with_lease_id(lease_id),
info,
Some(lease_id),
)
.await
{
tracing::error!("Failed to register discoverable service: {:?}", e);
cancel_token.cancel();
return Err(error!("Failed to register discoverable service"));
}
{
tracing::error!("Failed to register discoverable service: {:?}", e);
cancel_token.cancel();
return Err(error!("Failed to register discoverable service"));
}
task.await??;
......
......@@ -4,8 +4,8 @@
use super::Result;
use derive_builder::Builder;
use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
providers::{Env, Format, Serialized, Toml},
};
use serde::{Deserialize, Serialize};
use std::fmt;
......@@ -371,12 +371,14 @@ mod tests {
let result = RuntimeConfig::from_settings();
assert!(result.is_err());
if let Err(e) = result {
assert!(e
.to_string()
.contains("num_worker_threads: Validation error"));
assert!(e
.to_string()
.contains("max_blocking_threads: Validation error"));
assert!(
e.to_string()
.contains("num_worker_threads: Validation error")
);
assert!(
e.to_string()
.contains("max_blocking_threads: Validation error")
);
}
Ok(())
},
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::{transports::etcd, Result};
use crate::{Result, transports::etcd};
pub use etcd::Lease;
......
......@@ -16,15 +16,15 @@
pub use crate::component::Component;
use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{
ErrorContext, RuntimeCallback,
component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
discovery::DiscoveryClient,
metrics::MetricsRegistry,
service::ServiceClient,
transports::{etcd, nats, tcp},
ErrorContext, RuntimeCallback,
};
use super::{error, Arc, DistributedRuntime, OnceCell, Result, Runtime, SystemHealth, Weak, OK};
use super::{Arc, DistributedRuntime, OK, OnceCell, Result, Runtime, SystemHealth, Weak, error};
use std::sync::OnceLock;
use derive_getters::Dissolve;
......@@ -164,7 +164,9 @@ impl DistributedRuntime {
tracing::warn!("Failed to initialize system status start time: {}", e);
}
tracing::debug!("System status server HTTP endpoints disabled, but uptime metrics are being tracked");
tracing::debug!(
"System status server HTTP endpoints disabled, but uptime metrics are being tracked"
);
}
Ok(distributed_runtime)
......
......@@ -7,7 +7,7 @@
//! the entire distributed system, complementing the component-specific
//! instance listing in `component.rs`.
use crate::component::{Instance, INSTANCE_ROOT_PATH};
use crate::component::{INSTANCE_ROOT_PATH, Instance};
use crate::transports::etcd::Client as EtcdClient;
pub async fn list_all_instances(etcd_client: &EtcdClient) -> anyhow::Result<Vec<Instance>> {
......
......@@ -25,7 +25,7 @@ use std::{
use tokio::sync::Mutex;
pub use anyhow::{
anyhow as error, bail as raise, Context as ErrorContext, Error, Ok as OK, Result,
Context as ErrorContext, Error, Ok as OK, Result, anyhow as error, bail as raise,
};
use async_once_cell::OnceCell;
......
......@@ -30,43 +30,43 @@ use std::collections::{BTreeMap, HashMap};
use std::sync::Once;
use figment::{
providers::{Format, Serialized, Toml},
Figment,
providers::{Format, Serialized, Toml},
};
use serde::{Deserialize, Serialize};
use tracing::level_filters::LevelFilter;
use tracing::{Event, Subscriber};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::time::FormatTime;
use tracing_subscriber::fmt::time::LocalTime;
use tracing_subscriber::fmt::time::SystemTime;
use tracing_subscriber::fmt::time::UtcTime;
use tracing_subscriber::fmt::{format::Writer, FormattedFields};
use tracing_subscriber::fmt::{FmtContext, FormatFields};
use tracing_subscriber::fmt::{FormattedFields, format::Writer};
use tracing_subscriber::prelude::*;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::{filter::Directive, fmt};
use crate::config::{disable_ansi_logging, jsonl_logging_enabled};
use async_nats::{HeaderMap, HeaderValue};
use axum::extract::FromRequestParts;
use axum::http;
use axum::http::request::Parts;
use axum::http::Request;
use axum::http::request::Parts;
use serde_json::Value;
use std::convert::Infallible;
use std::time::Instant;
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
use tracing::field::Field;
use tracing::span;
use tracing::Id;
use tracing::Span;
use tracing::field::Field;
use tracing::span;
use tracing_subscriber::Layer;
use tracing_subscriber::Registry;
use tracing_subscriber::field::Visit;
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::layer::Context;
use tracing_subscriber::registry::SpanData;
use tracing_subscriber::Layer;
use tracing_subscriber::Registry;
use uuid::Uuid;
/// ENV used to set the log level
......@@ -340,18 +340,15 @@ where
x_dynamo_request_id = Some(x_request_id_input.to_string());
}
if parent_id.is_none() {
if let Some(parent_span_id) = ctx.current_span().id() {
if let Some(parent_span) = ctx.span(parent_span_id) {
let parent_ext = parent_span.extensions();
if let Some(parent_tracing_context) =
parent_ext.get::<DistributedTraceContext>()
{
trace_id = Some(parent_tracing_context.trace_id.clone());
parent_id = Some(parent_tracing_context.span_id.clone());
tracestate = parent_tracing_context.tracestate.clone();
}
}
if parent_id.is_none()
&& let Some(parent_span_id) = ctx.current_span().id()
&& let Some(parent_span) = ctx.span(parent_span_id)
{
let parent_ext = parent_span.extensions();
if let Some(parent_tracing_context) = parent_ext.get::<DistributedTraceContext>() {
trace_id = Some(parent_tracing_context.trace_id.clone());
parent_id = Some(parent_tracing_context.span_id.clone());
tracestate = parent_tracing_context.tracestate.clone();
}
}
......@@ -787,7 +784,7 @@ impl tracing::field::Visit for JsonVisitor {
#[cfg(test)]
pub mod tests {
use super::*;
use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow};
use chrono::{DateTime, Utc};
use jsonschema::{Draft, JSONSchema};
use serde_json::Value;
......@@ -1009,40 +1006,37 @@ pub mod tests {
// Parent span has no parent_id
for log_line in &lines {
if let Some(span_name) = log_line.get("span_name") {
if let Some(span_name_str) = span_name.as_str() {
if span_name_str == "parent" {
assert!(log_line.get("parent_id").is_none());
}
}
if let Some(span_name) = log_line.get("span_name")
&& let Some(span_name_str) = span_name.as_str()
&& span_name_str == "parent"
{
assert!(log_line.get("parent_id").is_none());
}
}
// Child span's parent_id is parent_span_id
for log_line in &lines {
if let Some(span_name) = log_line.get("span_name") {
if let Some(span_name_str) = span_name.as_str() {
if span_name_str == "child" {
assert_eq!(
log_line.get("parent_id").unwrap().as_str().unwrap(),
&parent_span_id
);
}
}
if let Some(span_name) = log_line.get("span_name")
&& let Some(span_name_str) = span_name.as_str()
&& span_name_str == "child"
{
assert_eq!(
log_line.get("parent_id").unwrap().as_str().unwrap(),
&parent_span_id
);
}
}
// Grandchild span's parent_id is child_span_id
for log_line in &lines {
if let Some(span_name) = log_line.get("span_name") {
if let Some(span_name_str) = span_name.as_str() {
if span_name_str == "grandchild" {
assert_eq!(
log_line.get("parent_id").unwrap().as_str().unwrap(),
&child_span_id
);
}
}
if let Some(span_name) = log_line.get("span_name")
&& let Some(span_name_str) = span_name.as_str()
&& span_name_str == "grandchild"
{
assert_eq!(
log_line.get("parent_id").unwrap().as_str().unwrap(),
&child_span_id
);
}
}
......
......@@ -33,14 +33,14 @@ use std::collections::HashMap;
// Import commonly used items to avoid verbose prefixes
use prometheus_names::{
build_metric_name, labels, name_prefix, nats_client, nats_service, work_handler,
COMPONENT_NATS_METRICS, DRT_NATS_METRICS,
COMPONENT_NATS_METRICS, DRT_NATS_METRICS, build_metric_name, labels, name_prefix, nats_client,
nats_service, work_handler,
};
// Pipeline imports for endpoint creation
use crate::pipeline::{
async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut,
ResponseStream, SingleIn,
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn, async_trait,
network::Ingress,
};
use crate::protocols::annotated::Annotated;
use crate::stream;
......@@ -985,9 +985,11 @@ mod test_metricsregistry_prefixes {
// Valid namespace works
let valid_namespace = drt.namespace("ns567").unwrap();
assert!(valid_namespace
.create_counter("test_counter", "A test counter", &[])
.is_ok());
assert!(
valid_namespace
.create_counter("test_counter", "A test counter", &[])
.is_ok()
);
}
#[tokio::test]
......@@ -1042,8 +1044,8 @@ mod test_metricsregistry_prefixes {
#[cfg(test)]
mod test_metricsregistry_prometheus_fmt_outputs {
use super::prometheus_names::name_prefix;
use super::prometheus_names::{nats_client, nats_service};
use super::prometheus_names::{COMPONENT_NATS_METRICS, DRT_NATS_METRICS};
use super::prometheus_names::{nats_client, nats_service};
use super::*;
use crate::distributed::test_helpers::create_test_drt_async;
use prometheus::Counter;
......@@ -1289,9 +1291,11 @@ dynamo_component_nats_service_total_errors 5"#;
// Test extract_metrics (only actual metric lines, excluding help/type)
let metrics_only = super::test_helpers::extract_metrics(test_input);
assert_eq!(metrics_only.len(), 6); // 6 actual metric lines (excluding help/type)
assert!(metrics_only
.iter()
.all(|line| line.starts_with("dynamo_component") && !line.starts_with("#")));
assert!(
metrics_only
.iter()
.all(|line| line.starts_with("dynamo_component") && !line.starts_with("#"))
);
println!("✓ All refactored filter functions work correctly!");
}
......@@ -1301,13 +1305,13 @@ dynamo_component_nats_service_total_errors 5"#;
#[cfg(test)]
mod test_metricsregistry_nats {
use super::prometheus_names::name_prefix;
use super::prometheus_names::{nats_client, nats_service};
use super::prometheus_names::{COMPONENT_NATS_METRICS, DRT_NATS_METRICS};
use super::prometheus_names::{nats_client, nats_service};
use super::*;
use crate::distributed::test_helpers::create_test_drt_async;
use crate::pipeline::PushRouter;
use crate::{DistributedRuntime, Runtime};
use tokio::time::{sleep, Duration};
use tokio::time::{Duration, sleep};
#[tokio::test]
async fn test_drt_nats_metrics() {
// Setup real DRT and registry using the test-friendly constructor
......@@ -1361,8 +1365,7 @@ mod test_metricsregistry_nats {
// Compare the sorted lists
assert_eq!(
actual_drt_nats_metrics_sorted,
expect_drt_nats_metrics_sorted,
actual_drt_nats_metrics_sorted, expect_drt_nats_metrics_sorted,
"DRT_NATS_METRICS with prefix and expected_nats_metrics should be identical when sorted"
);
......@@ -1429,8 +1432,7 @@ mod test_metricsregistry_nats {
// Compare the sorted lists
assert_eq!(
actual_component_nats_metrics_sorted,
expect_component_nats_metrics_sorted,
actual_component_nats_metrics_sorted, expect_component_nats_metrics_sorted,
"COMPONENT_NATS_METRICS with prefix and expected_nats_metrics should be identical when sorted"
);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment