Unverified Commit 1065ff1a authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: adding type-erased AnyAsyncEngine (#1601)

parent 9622caaa
......@@ -13,12 +13,80 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc};
//! Asynchronous Engine System with Type Erasure Support
//!
//! This module provides the core asynchronous engine abstraction for Dynamo's runtime system.
//! It defines the `AsyncEngine` trait for streaming engines and provides sophisticated
//! type-erasure capabilities for managing heterogeneous engine collections.
//!
//! ## Type Erasure Overview
//!
//! Type erasure is a critical feature that allows storing different `AsyncEngine` implementations
//! with varying generic type parameters in a single collection (e.g., `HashMap<String, Arc<dyn AnyAsyncEngine>>`).
//! This is essential for:
//!
//! - **Dynamic Engine Management**: Registering and retrieving engines at runtime based on configuration
//! - **Plugin Systems**: Loading different engine implementations without compile-time knowledge
//! - **Service Discovery**: Managing multiple engine types in a unified registry
//!
//! ## Implementation Details
//!
//! The type-erasure system uses several advanced Rust features:
//!
//! - **Trait Objects (`dyn Trait`)**: For runtime polymorphism without compile-time type information
//! - **`std::any::TypeId`**: For runtime type checking during downcasting
//! - **`std::any::Any`**: For type-erased storage and safe downcasting
//! - **`PhantomData`**: For maintaining type relationships in generic wrappers
//! - **Extension Traits**: For ergonomic API design without modifying existing types
//!
//! ## Safety Considerations
//!
//! ⚠️ **IMPORTANT**: The type-erasure system relies on precise type matching at runtime.
//! When modifying these traits or their implementations:
//!
//! - **Never change the type ID logic** in `AnyAsyncEngine` implementations
//! - **Maintain the blanket `Data` implementation** for all `Send + Sync + 'static` types
//! - **Test downcasting thoroughly** when adding new engine types
//! - **Document any changes** that affect the type-erasure behavior
//!
//! ## Usage Example
//!
//! ```rust
//! use std::collections::HashMap;
//! use std::sync::Arc;
//! use crate::engine::{AsyncEngine, AsAnyAsyncEngine, DowncastAnyAsyncEngine};
//!
//! // Create typed engines
//! let string_engine: Arc<dyn AsyncEngine<String, String, ()>> = Arc::new(MyStringEngine::new());
//! let int_engine: Arc<dyn AsyncEngine<i32, i32, ()>> = Arc::new(MyIntEngine::new());
//!
//! // Store in heterogeneous collection
//! let mut engines: HashMap<String, Arc<dyn AnyAsyncEngine>> = HashMap::new();
//! engines.insert("string".to_string(), string_engine.into_any_engine());
//! engines.insert("int".to_string(), int_engine.into_any_engine());
//!
//! // Retrieve and downcast safely
//! if let Some(typed_engine) = engines.get("string").unwrap().downcast::<String, String, ()>() {
//! let result = typed_engine.generate("hello".to_string()).await;
//! }
//! ```
use std::{
any::{Any, TypeId},
fmt::Debug,
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
};
pub use async_trait::async_trait;
use futures::stream::Stream;
/// All [`Send`] + [`Sync`] + `'static` types can be used as [`AsyncEngine`] request and response types.
///
/// This is implemented as a blanket implementation for all types that meet the bounds.
/// **Do not manually implement this trait** - the blanket implementation covers all valid types.
pub trait Data: Send + Sync + 'static {}
impl<T: Send + Sync + 'static> Data for T {}
......@@ -43,6 +111,15 @@ pub trait AsyncEngineController: Send + Sync {}
/// The [`AsyncEngineContext`] trait defines the interface to control the resulting stream
/// produced by the engine.
///
/// This trait provides lifecycle management for async operations, including:
/// - Stream identification via unique IDs
/// - Graceful shutdown capabilities (`stop_generating`)
/// - Immediate termination capabilities (`kill`)
/// - Status checking for stopped/killed states
///
/// Implementations should ensure thread-safety and proper state management
/// across concurrent access patterns.
#[async_trait]
pub trait AsyncEngineContext: Send + Sync + Debug {
/// Unique ID for the Stream
......@@ -84,22 +161,48 @@ pub trait AsyncEngineContext: Send + Sync + Debug {
fn kill(&self);
}
/// Provides access to the [`AsyncEngineContext`] associated with an engine operation.
///
/// This trait is implemented by both unary and streaming engine results, allowing
/// uniform access to context information regardless of the operation type.
pub trait AsyncEngineContextProvider: Send + Sync + Debug {
fn context(&self) -> Arc<dyn AsyncEngineContext>;
}
/// A unary (single-response) asynchronous engine operation.
///
/// This trait combines `Future` semantics with context provider capabilities,
/// representing a single async operation that produces one result.
pub trait AsyncEngineUnary<Resp: Data>:
Future<Output = Resp> + AsyncEngineContextProvider + Send + Sync
{
}
/// A streaming asynchronous engine operation.
///
/// This trait combines `Stream` semantics with context provider capabilities,
/// representing a continuous async operation that produces multiple results over time.
pub trait AsyncEngineStream<Resp: Data>:
Stream<Item = Resp> + AsyncEngineContextProvider + Send + Sync
{
}
/// Engine is a trait that defines the interface for a steaming LLM completion engine.
/// Engine is a trait that defines the interface for a streaming engine.
/// The synchronous Engine version is does not need to be awaited.
///
/// This is the core trait for all async engine implementations. It provides:
/// - Generic type parameters for request, response, and error types
/// - Async generation capabilities with proper error handling
/// - Thread-safe design with `Send + Sync` bounds
///
/// ## Type Parameters
/// - `Req`: The request type that implements `Data`
/// - `Resp`: The response type that implements both `Data` and `AsyncEngineContextProvider`
/// - `E`: The error type that implements `Data`
///
/// ## Implementation Notes
/// Implementations should ensure proper error handling and resource management.
/// The `generate` method should be cancellable via the response's context provider.
#[async_trait]
pub trait AsyncEngine<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data>:
Send + Sync
......@@ -166,3 +269,247 @@ impl<T: Data> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineStream<T>>>
AsyncEngineContextProvider::context(&**self)
}
}
/// A type-erased `AsyncEngine`.
///
/// This trait enables storing heterogeneous `AsyncEngine` implementations in collections
/// by erasing their specific generic type parameters. It provides runtime type information
/// and safe downcasting capabilities.
///
/// ## Type Erasure Mechanism
/// The trait uses `std::any::TypeId` to preserve type information at runtime, allowing
/// safe downcasting back to the original `AsyncEngine<Req, Resp, E>` types.
///
/// ## Safety Guarantees
/// - Type IDs are preserved exactly as they were during type erasure
/// - Downcasting is only possible to the original type combination
/// - Incorrect downcasts return `None` rather than panicking
///
/// ## Implementation Notes
/// This trait is implemented by the internal `AnyEngineWrapper` struct. Users should
/// not implement this trait directly - use the `AsAnyAsyncEngine` extension trait instead.
pub trait AnyAsyncEngine: Send + Sync {
/// Returns the `TypeId` of the request type used by this engine.
fn request_type_id(&self) -> TypeId;
/// Returns the `TypeId` of the response type used by this engine.
fn response_type_id(&self) -> TypeId;
/// Returns the `TypeId` of the error type used by this engine.
fn error_type_id(&self) -> TypeId;
/// Provides access to the underlying engine as a `dyn Any` for downcasting.
fn as_any(&self) -> &dyn Any;
}
/// An internal wrapper to hold a typed `AsyncEngine` behind the `AnyAsyncEngine` trait object.
///
/// This struct uses `PhantomData<fn(Req, Resp, E)>` to maintain the type relationship
/// without storing the types directly, enabling the type-erasure mechanism.
///
/// ## PhantomData Usage
/// The `PhantomData<fn(Req, Resp, E)>` ensures that the compiler knows about the
/// generic type parameters without requiring them to be `'static`, which would
/// prevent storing non-static types in the engine.
struct AnyEngineWrapper<Req, Resp, E>
where
Req: Data,
Resp: Data + AsyncEngineContextProvider,
E: Data,
{
engine: Arc<dyn AsyncEngine<Req, Resp, E>>,
_phantom: PhantomData<fn(Req, Resp, E)>,
}
impl<Req, Resp, E> AnyAsyncEngine for AnyEngineWrapper<Req, Resp, E>
where
Req: Data,
Resp: Data + AsyncEngineContextProvider,
E: Data,
{
fn request_type_id(&self) -> TypeId {
TypeId::of::<Req>()
}
fn response_type_id(&self) -> TypeId {
TypeId::of::<Resp>()
}
fn error_type_id(&self) -> TypeId {
TypeId::of::<E>()
}
fn as_any(&self) -> &dyn Any {
&self.engine
}
}
/// An extension trait that provides a convenient way to type-erase an `AsyncEngine`.
///
/// This trait provides the `.into_any_engine()` method on any `Arc<dyn AsyncEngine<...>>`,
/// enabling ergonomic type erasure without explicit wrapper construction.
///
/// ## Usage
/// ```rust
/// use crate::engine::AsAnyAsyncEngine;
///
/// let typed_engine: Arc<dyn AsyncEngine<String, String, ()>> = Arc::new(MyEngine::new());
/// let any_engine = typed_engine.into_any_engine();
/// ```
pub trait AsAnyAsyncEngine {
/// Converts a typed `AsyncEngine` into a type-erased `AnyAsyncEngine`.
fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine>;
}
impl<Req, Resp, E> AsAnyAsyncEngine for Arc<dyn AsyncEngine<Req, Resp, E>>
where
Req: Data,
Resp: Data + AsyncEngineContextProvider,
E: Data,
{
fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine> {
Arc::new(AnyEngineWrapper {
engine: self,
_phantom: PhantomData,
})
}
}
/// An extension trait that provides a convenient method to downcast an `AnyAsyncEngine`.
///
/// This trait provides the `.downcast<Req, Resp, E>()` method on `Arc<dyn AnyAsyncEngine>`,
/// enabling safe downcasting back to the original typed engine.
///
/// ## Safety
/// The downcast method performs runtime type checking using `TypeId` comparison.
/// It will only succeed if the type parameters exactly match the original engine's types.
///
/// ## Usage
/// ```rust
/// use crate::engine::DowncastAnyAsyncEngine;
///
/// let any_engine: Arc<dyn AnyAsyncEngine> = // ... from collection
/// if let Some(typed_engine) = any_engine.downcast::<String, String, ()>() {
/// // Use the typed engine
/// let result = typed_engine.generate("hello".to_string()).await;
/// }
/// ```
pub trait DowncastAnyAsyncEngine {
/// Attempts to downcast an `AnyAsyncEngine` to a specific `AsyncEngine` type.
///
/// Returns `Some(engine)` if the type parameters match the original engine,
/// or `None` if the types don't match.
fn downcast<Req, Resp, E>(&self) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>
where
Req: Data,
Resp: Data + AsyncEngineContextProvider,
E: Data;
}
impl DowncastAnyAsyncEngine for Arc<dyn AnyAsyncEngine> {
fn downcast<Req, Resp, E>(&self) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>
where
Req: Data,
Resp: Data + AsyncEngineContextProvider,
E: Data,
{
if self.request_type_id() == TypeId::of::<Req>()
&& self.response_type_id() == TypeId::of::<Resp>()
&& self.error_type_id() == TypeId::of::<E>()
{
self.as_any()
.downcast_ref::<Arc<dyn AsyncEngine<Req, Resp, E>>>()
.cloned()
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
// 1. Define mock data structures
#[derive(Debug, PartialEq)]
struct Req1(String);
#[derive(Debug, PartialEq)]
struct Resp1(String);
// Dummy context provider implementation for the response
impl AsyncEngineContextProvider for Resp1 {
fn context(&self) -> Arc<dyn AsyncEngineContext> {
// For this test, we don't need a real context.
unimplemented!()
}
}
#[derive(Debug)]
struct Err1;
// A different set of types for testing failure cases
#[derive(Debug)]
struct Req2;
#[derive(Debug)]
struct Resp2;
impl AsyncEngineContextProvider for Resp2 {
fn context(&self) -> Arc<dyn AsyncEngineContext> {
unimplemented!()
}
}
// 2. Define a mock engine
struct MockEngine;
#[async_trait]
impl AsyncEngine<Req1, Resp1, Err1> for MockEngine {
async fn generate(&self, request: Req1) -> Result<Resp1, Err1> {
Ok(Resp1(format!("response to {}", request.0)))
}
}
#[tokio::test]
async fn test_engine_type_erasure_and_downcast() {
// 3. Create a typed engine
let typed_engine: Arc<dyn AsyncEngine<Req1, Resp1, Err1>> = Arc::new(MockEngine);
// 4. Use the extension trait to erase the type
let any_engine = typed_engine.into_any_engine();
// Check type IDs are preserved
assert_eq!(any_engine.request_type_id(), TypeId::of::<Req1>());
assert_eq!(any_engine.response_type_id(), TypeId::of::<Resp1>());
assert_eq!(any_engine.error_type_id(), TypeId::of::<Err1>());
// 5. Use the new downcast method on the Arc
let downcasted_engine = any_engine.downcast::<Req1, Resp1, Err1>();
// 6. Assert success
assert!(downcasted_engine.is_some());
// We can even use the downcasted engine
let response = downcasted_engine
.unwrap()
.generate(Req1("hello".to_string()))
.await;
assert_eq!(response.unwrap(), Resp1("response to hello".to_string()));
// 7. Assert failure for wrong types
let failed_downcast = any_engine.downcast::<Req2, Resp2, Err1>();
assert!(failed_downcast.is_none());
// 8. HashMap usage test
let mut engine_map: HashMap<String, Arc<dyn AnyAsyncEngine>> = HashMap::new();
engine_map.insert("mock".to_string(), any_engine);
let retrieved_engine = engine_map.get("mock").unwrap();
let final_engine = retrieved_engine.downcast::<Req1, Resp1, Err1>().unwrap();
let final_response = final_engine.generate(Req1("world".to_string())).await;
assert_eq!(
final_response.unwrap(),
Resp1("response to world".to_string())
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment