Unverified Commit cf630bf7 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

refactor: Make the Runtime and DistributedRuntime fields private (#4193)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 0e623146
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::context::{Context, callable_accepts_kwarg};
use dynamo_runtime::logging::get_distributed_tracing_context;
use std::sync::Arc;
use anyhow::{Error, Result};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use pyo3::{PyAny, PyErr};
use pyo3_async_runtimes::TaskLocals;
use pythonize::{depythonize, pythonize};
use std::sync::Arc;
pub use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
use tokio_util::sync::CancellationToken;
use dynamo_runtime::logging::get_distributed_tracing_context;
pub use dynamo_runtime::{
CancellationToken, Error, Result,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Data, ManyOut, ResponseStream, SingleIn,
async_trait,
},
pipeline::{AsyncEngine, AsyncEngineContextProvider, Data, ManyOut, ResponseStream, SingleIn},
protocols::annotated::Annotated,
};
pub use serde::{Deserialize, Serialize};
use super::context::{Context, callable_accepts_kwarg};
/// Add bindings from this crate to the provided module
pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
......@@ -87,7 +87,7 @@ impl PythonAsyncEngine {
}
}
#[async_trait]
#[async_trait::async_trait]
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for PythonAsyncEngine
where
Req: Data + Serialize,
......@@ -141,7 +141,7 @@ enum ResponseProcessingError {
OffloadError(String),
}
#[async_trait]
#[async_trait::async_trait]
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error>
for PythonServerStreamingEngine
where
......
......@@ -3,6 +3,7 @@
use std::sync::Arc;
use anyhow::{Error, Result, anyhow as error};
use pyo3::prelude::*;
use crate::{CancellationToken, engine::*, to_pyerr};
......@@ -10,7 +11,6 @@ use crate::{CancellationToken, engine::*, to_pyerr};
pub use dynamo_llm::endpoint_type::EndpointType;
pub use dynamo_llm::http::service::{error as http_error, service_v2};
pub use dynamo_runtime::{
Error, Result, error,
pipeline::{AsyncEngine, Data, ManyOut, SingleIn, async_trait},
protocols::annotated::Annotated,
};
......
......@@ -441,20 +441,20 @@ impl DistributedRuntime {
// Try to get existing runtime first, create new Worker only if needed
// This allows multiple DistributedRuntime instances to share the same tokio runtime
let runtime = rs::Worker::runtime_from_existing()
.or_else(|_| {
.or_else(|_| -> anyhow::Result<rs::Runtime> {
// No existing Worker, create new one
let worker = rs::Worker::from_settings()?;
// Initialize pyo3 bridge (only happens once per process)
INIT.get_or_try_init(|| {
INIT.get_or_try_init(|| -> anyhow::Result<()> {
let primary = worker.tokio_runtime()?;
pyo3_async_runtimes::tokio::init_with_runtime(primary).map_err(|e| {
rs::error!("failed to initialize pyo3 static runtime: {:?}", e)
anyhow::anyhow!("failed to initialize pyo3 static runtime: {:?}", e)
})?;
rs::OK(())
Ok(())
})?;
rs::OK(worker.runtime().clone())
Ok(worker.runtime().clone())
})
.map_err(to_pyerr)?;
......
......@@ -4,6 +4,7 @@
pub mod managed;
pub use managed::ManagedBlockPool;
use anyhow::Result;
use derive_builder::Builder;
use derive_getters::Dissolve;
use serde::{Deserialize, Serialize};
......@@ -31,8 +32,6 @@ use tokio::runtime::Handle;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::Result;
// Type aliases to reduce complexity across the module
type BlockPoolResult<T> = Result<T, BlockPoolError>;
type AsyncResponse<T> = Result<oneshot::Receiver<T>, BlockPoolError>;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq},
protocols::*,
scoring::LoadEvent,
};
use dynamo_runtime::metrics::{MetricsHierarchy, prometheus_names::kvstats};
use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::{
Result,
component::{Component, Namespace},
transports::nats::{NatsQueue, QUEUE_NAME, Slug},
};
use std::fmt;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, OnceLock};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use std::time::Duration;
use anyhow::Result;
use rmp_serde as rmps;
use serde::Deserialize;
use serde::Serialize;
use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use std::fmt;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_runtime::metrics::{MetricsHierarchy, prometheus_names::kvstats};
use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::{
component::{Component, Namespace},
transports::nats::{NatsQueue, QUEUE_NAME, Slug},
};
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq},
protocols::*,
scoring::LoadEvent,
};
// -------------------------------------------------------------------------
// KV Event Publishers -----------------------------------------------------
// -------------------------------------------------------------------------
......@@ -1025,7 +1027,7 @@ mod tests_startup_helpers {
&self,
event_name: impl AsRef<str> + Send + Sync,
event: &(impl serde::Serialize + Send + Sync),
) -> dynamo_runtime::Result<()> {
) -> anyhow::Result<()> {
let bytes = rmp_serde::to_vec(event).unwrap();
self.published
.lock()
......@@ -1038,7 +1040,7 @@ mod tests_startup_helpers {
&self,
event_name: impl AsRef<str> + Send + Sync,
bytes: Vec<u8>,
) -> dynamo_runtime::Result<()> {
) -> anyhow::Result<()> {
self.published
.lock()
.unwrap()
......
......@@ -6,31 +6,33 @@
//! This module provides an AsyncEngine implementation that wraps the Scheduler
//! to provide streaming token generation with realistic timing simulation.
use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures::StreamExt;
use rand::Rng;
use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{
Result,
component::Component,
engine::AsyncEngineContextProvider,
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider,
};
use futures::StreamExt;
use rand::Rng;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
pub const MOCKER_COMPONENT: &str = "mocker";
......
......@@ -300,7 +300,6 @@ mod integration_tests {
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::discovery::DiscoveryQuery;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use std::sync::Arc;
#[tokio::test]
......
......@@ -1040,6 +1040,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
name = "hello_world"
version = "0.6.1"
dependencies = [
"anyhow",
"dynamo-runtime",
]
......@@ -2633,6 +2634,7 @@ dependencies = [
name = "service_metrics"
version = "0.6.1"
dependencies = [
"anyhow",
"dynamo-runtime",
"futures",
"serde",
......
......@@ -20,5 +20,6 @@ repository = "https://github.com/ai-dynamo/dynamo.git"
[workspace.dependencies]
# local or crates.io
anyhow = "1"
dynamo-runtime = { path = "../" }
prometheus = { version = "0.14" }
......@@ -13,3 +13,4 @@ homepage.workspace = true
dynamo-runtime = { workspace = true }
# third-party
anyhow = { workspace = true }
......@@ -2,18 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{
DistributedRuntime, Result, Runtime, Worker, logging, pipeline::PushRouter,
DistributedRuntime, Runtime, Worker, logging, pipeline::PushRouter,
protocols::annotated::Annotated, stream::StreamExt,
};
use hello_world::DEFAULT_NAMESPACE;
fn main() -> Result<()> {
fn main() -> anyhow::Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
async fn app(runtime: Runtime) -> anyhow::Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let client = distributed
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{
DistributedRuntime, Result, Runtime, Worker, logging,
DistributedRuntime, Runtime, Worker, logging,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
async_trait, network::Ingress,
......@@ -13,13 +13,13 @@ use dynamo_runtime::{
use hello_world::DEFAULT_NAMESPACE;
use std::sync::Arc;
fn main() -> Result<()> {
fn main() -> anyhow::Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
async fn app(runtime: Runtime) -> anyhow::Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
backend(distributed).await
}
......@@ -34,7 +34,10 @@ impl RequestHandler {
#[async_trait]
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for RequestHandler {
async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
async fn generate(
&self,
input: SingleIn<String>,
) -> anyhow::Result<ManyOut<Annotated<String>>> {
let (data, ctx) = input.into_parts();
let chars = data
......@@ -48,7 +51,7 @@ impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for Reques
}
}
async fn backend(runtime: DistributedRuntime) -> Result<()> {
async fn backend(runtime: DistributedRuntime) -> anyhow::Result<()> {
// attach an ingress to an engine
let ingress = Ingress::for_engine(RequestHandler::new())?;
......
......@@ -14,6 +14,7 @@ repository.workspace = true
dynamo-runtime = { workspace = true }
# third-party
anyhow = { workspace = true }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = { version = "1" }
......
......@@ -5,17 +5,17 @@ use futures::StreamExt;
use service_metrics::DEFAULT_NAMESPACE;
use dynamo_runtime::{
DistributedRuntime, Result, Runtime, Worker, logging, pipeline::PushRouter,
DistributedRuntime, Runtime, Worker, logging, pipeline::PushRouter,
protocols::annotated::Annotated, utils::Duration,
};
fn main() -> Result<()> {
fn main() -> anyhow::Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
async fn app(runtime: Runtime) -> anyhow::Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let namespace = distributed.namespace(DEFAULT_NAMESPACE)?;
......
......@@ -4,7 +4,7 @@
use service_metrics::{DEFAULT_NAMESPACE, MyStats};
use dynamo_runtime::{
DistributedRuntime, Result, Runtime, Worker, logging,
DistributedRuntime, Runtime, Worker, logging,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
async_trait, network::Ingress,
......@@ -14,13 +14,13 @@ use dynamo_runtime::{
};
use std::sync::Arc;
fn main() -> Result<()> {
fn main() -> anyhow::Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
async fn app(runtime: Runtime) -> anyhow::Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
backend(distributed).await
}
......@@ -35,7 +35,10 @@ impl RequestHandler {
#[async_trait]
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for RequestHandler {
async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
async fn generate(
&self,
input: SingleIn<String>,
) -> anyhow::Result<ManyOut<Annotated<String>>> {
let (data, ctx) = input.into_parts();
let chars = data
......@@ -49,7 +52,7 @@ impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for Reques
}
}
async fn backend(runtime: DistributedRuntime) -> Result<()> {
async fn backend(runtime: DistributedRuntime) -> anyhow::Result<()> {
// attach an ingress to an engine
let ingress = Ingress::for_engine(RequestHandler::new())?;
......
......@@ -5,17 +5,17 @@ use futures::StreamExt;
use system_metrics::{DEFAULT_COMPONENT, DEFAULT_ENDPOINT, DEFAULT_NAMESPACE};
use dynamo_runtime::{
DistributedRuntime, Result, Runtime, Worker, logging, pipeline::PushRouter,
DistributedRuntime, Runtime, Worker, logging, pipeline::PushRouter,
protocols::annotated::Annotated, utils::Duration,
};
fn main() -> Result<()> {
fn main() -> anyhow::Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
async fn app(runtime: Runtime) -> anyhow::Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let namespace = distributed.namespace(DEFAULT_NAMESPACE)?;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{DistributedRuntime, Result, Runtime, Worker, logging};
use dynamo_runtime::{DistributedRuntime, Runtime, Worker, logging};
use system_metrics::backend;
fn main() -> Result<()> {
fn main() -> anyhow::Result<()> {
logging::init();
let worker = Worker::from_settings()?;
worker.execute(app)
}
async fn app(runtime: Runtime) -> Result<()> {
async fn app(runtime: Runtime) -> anyhow::Result<()> {
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
backend(distributed, None).await
}
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{
DistributedRuntime, Result,
DistributedRuntime,
metrics::MetricsHierarchy,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn,
......@@ -64,7 +64,10 @@ impl RequestHandler {
#[async_trait]
impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for RequestHandler {
async fn generate(&self, input: SingleIn<String>) -> Result<ManyOut<Annotated<String>>> {
async fn generate(
&self,
input: SingleIn<String>,
) -> anyhow::Result<ManyOut<Annotated<String>>> {
let (data, ctx) = input.into_parts();
// Track data bytes processed if metrics are available
......@@ -85,7 +88,7 @@ impl AsyncEngine<SingleIn<String>, ManyOut<Annotated<String>>, Error> for Reques
/// Backend function that sets up the system status server with metrics and ingress handler
/// This function can be reused by integration tests to ensure they use the exact same setup
pub async fn backend(drt: DistributedRuntime, endpoint_name: Option<&str>) -> Result<()> {
pub async fn backend(drt: DistributedRuntime, endpoint_name: Option<&str>) -> anyhow::Result<()> {
let endpoint_name = endpoint_name.unwrap_or(DEFAULT_ENDPOINT);
let mut component = drt
......
......@@ -39,7 +39,7 @@ use crate::{
};
use super::{
DistributedRuntime, Result, Runtime, error,
DistributedRuntime, Runtime,
traits::*,
transports::etcd::{COMPONENT_KEYWORD, ENDPOINT_KEYWORD},
transports::nats::Slug,
......@@ -302,7 +302,7 @@ impl Component {
/// Scrape ServiceSet, which contains NATS stats as well as user defined stats
/// embedded in data field of ServiceInfo.
pub async fn scrape_stats(&self, timeout: Duration) -> Result<ServiceSet> {
pub async fn scrape_stats(&self, timeout: Duration) -> anyhow::Result<ServiceSet> {
// Debug: scraping stats for component
let service_name = self.service_name();
let Some(service_client) = self.drt().service_client() else {
......@@ -320,7 +320,7 @@ impl Component {
/// then subsequent scrapes occur at a fixed interval of 9.8 seconds (MAX_WAIT_MS),
/// which should be near or smaller than typical Prometheus scraping intervals to ensure
/// metrics are fresh when Prometheus collects them.
pub fn start_scraping_nats_service_component_metrics(&self) -> Result<()> {
pub fn start_scraping_nats_service_component_metrics(&self) -> anyhow::Result<()> {
const MAX_WAIT_MS: std::time::Duration = std::time::Duration::from_millis(9800); // Should be <= Prometheus scrape interval
// If there is another component with the same service name, this will fail.
......@@ -373,7 +373,7 @@ impl Component {
/// Returns a stream of `ServiceInfo` objects.
/// This should be consumed by a [`tokio::time::timeout_at`] because each service
/// will only respond once, but there is no way to know when all services have responded.
pub async fn stats_stream(&self) -> Result<()> {
pub async fn stats_stream(&self) -> anyhow::Result<()> {
unimplemented!("collect_stats")
}
......@@ -383,7 +383,7 @@ impl Component {
// Pre-check to save cost of creating the service, but don't hold the lock
if self
.drt
.component_registry
.component_registry()
.inner
.lock()
.await
......@@ -400,7 +400,7 @@ impl Component {
let (nats_service, stats_reg) =
service::build_nats_service(nats_client, self, description).await?;
let mut guard = self.drt.component_registry.inner.lock().await;
let mut guard = self.drt.component_registry().inner.lock().await;
if !guard.services.contains_key(&service_name) {
// Normal case
guard.services.insert(service_name.clone(), nats_service);
......@@ -590,7 +590,7 @@ impl Endpoint {
)
}
pub async fn client(&self) -> Result<client::Client> {
pub async fn client(&self) -> anyhow::Result<client::Client> {
if self.is_static {
client::Client::new_static(self.clone()).await
} else {
......@@ -655,7 +655,11 @@ impl std::fmt::Display for Namespace {
}
impl Namespace {
pub(crate) fn new(runtime: DistributedRuntime, name: String, is_static: bool) -> Result<Self> {
pub(crate) fn new(
runtime: DistributedRuntime,
name: String,
is_static: bool,
) -> anyhow::Result<Self> {
Ok(NamespaceBuilder::default()
.runtime(Arc::new(runtime))
.name(name)
......@@ -664,7 +668,7 @@ impl Namespace {
}
/// Create a [`Component`] in the namespace whose endpoints can be discovered with etcd
pub fn component(&self, name: impl Into<String>) -> Result<Component> {
pub fn component(&self, name: impl Into<String>) -> anyhow::Result<Component> {
Ok(ComponentBuilder::from_runtime(self.runtime.clone())
.name(name)
.namespace(self.clone())
......@@ -673,7 +677,7 @@ impl Namespace {
}
/// Create a [`Namespace`] in the parent namespace
pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
pub fn namespace(&self, name: impl Into<String>) -> anyhow::Result<Namespace> {
Ok(NamespaceBuilder::default()
.runtime(self.runtime.clone())
.name(name.into())
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::{collections::HashMap, time::Duration};
use anyhow::Result;
use arc_swap::ArcSwap;
use futures::StreamExt;
use tokio::net::unix::pipe::Receiver;
use crate::{
component::{Endpoint, Instance},
pipeline::async_trait,
pipeline::{
AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
SingleIn,
},
storage::key_value_store::{KeyValueStoreManager, WatchEvent},
traits::DistributedRuntimeProvider,
transports::etcd::Client as EtcdClient,
};
use arc_swap::ArcSwap;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::unix::pipe::Receiver;
use crate::{pipeline::async_trait, transports::etcd::Client as EtcdClient};
use super::*;
/// Each state will have a nonce associated with it
/// The state will be emitted in a watch channel, so we can observe the
......@@ -318,7 +321,7 @@ impl Client {
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
let secondary = endpoint.component.drt.runtime.secondary().clone();
let secondary = endpoint.component.drt.runtime().secondary().clone();
secondary.spawn(async move {
tracing::debug!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment