// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! Context Module //! //! There are two context object defined in this module: //! //! - [`Context`] is an input context which is propagated through the processing pipeline, //! up to the point where the input is pass to an [`triton_distributed::engine::AsyncEngine`] for processing. //! - [`StreamContext`] is the input context transformed into to a type erased context that maintains the inputs //! registry and visitors. `StreamAdaptors` will amend themselves to the [`StreamContext`] to allow for the use std::ops::{Deref, DerefMut}; use std::sync::Arc; use super::{AsyncEngineContext, AsyncEngineContextProvider, Data}; use crate::engine::AsyncEngineController; use async_trait::async_trait; use super::registry::Registry; pub struct Context { current: T, controller: Arc, //todo: hold this as an arc registry: Registry, stages: Vec, } impl Context { // Create a new context with initial data pub fn new(current: T) -> Self { Context { current, controller: Arc::new(Controller::default()), registry: Registry::new(), stages: Vec::new(), } } pub fn with_controller(current: T, controller: Controller) -> Self { Context { current, controller: Arc::new(controller), registry: Registry::new(), stages: Vec::new(), } } pub fn with_id(current: T, id: String) -> Self { Context { current, controller: Arc::new(Controller::new(id)), registry: Registry::new(), stages: Vec::new(), } } pub fn id(&self) -> &str { self.controller.id() } pub fn controller(&self) -> &Controller { &self.controller } /// Insert an object into the registry with a specific key. pub fn insert(&mut self, key: K, value: U) { self.registry.insert_shared(key, value); } /// Insert a unique and takable object into the registry with a specific key. pub fn insert_unique(&mut self, key: K, value: U) { self.registry.insert_unique(key, value); } /// Retrieve an object from the registry by key and type. pub fn get(&self, key: &str) -> Result, String> { self.registry.get_shared(key) } /// Clone a unique object from the registry by key and type. pub fn clone_unique(&self, key: &str) -> Result { self.registry.clone_unique(key) } /// Take a unique object from the registry by key and type. pub fn take_unique(&mut self, key: &str) -> Result { self.registry.take_unique(key) } /// Transfer the Context to a new Object without updating the registry /// This returns a tuple of the previous object and the new Context pub fn transfer(self, new_current: U) -> (T, Context) { ( self.current, Context { current: new_current, controller: self.controller, registry: self.registry, stages: self.stages, }, ) } /// Separate out the current object and context pub fn into_parts(self) -> (T, Context<()>) { self.transfer(()) } pub fn stages(&self) -> &Vec { &self.stages } pub fn add_stage(&mut self, stage: &str) { self.stages.push(stage.to_string()); } /// Transforms the current context to another type using a provided function. pub fn map(self, f: F) -> Context where F: FnOnce(T) -> U, { // Use the transfer method to move the current value out let (current, temp_context) = self.transfer(()); // Apply the transformation function to the current value let new_current = f(current); // Use transfer again to create the new context with the transformed type temp_context.transfer(new_current).1 } pub fn try_map(self, f: F) -> Result, E> where F: FnOnce(T) -> Result, U: Send + Sync + 'static, { // Use the transfer method to move the current value out let (current, temp_context) = self.transfer(()); // Apply the transformation function to the current value let new_current = f(current)?; // Use transfer again to create the new context with the transformed type Ok(temp_context.transfer(new_current).1) } } impl std::fmt::Debug for Context { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Context") .field("id", &self.controller.id()) .finish() } } // Implement Deref to allow Context to act like &T impl Deref for Context { type Target = T; fn deref(&self) -> &Self::Target { &self.current } } // Implement DerefMut to allow Context to act like &mut T impl DerefMut for Context { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.current } } // Implement the custom trait for Context impl From for Context where T: Send + Sync + 'static, { fn from(current: T) -> Self { Context::new(current) } } // Define a custom trait for conversion from Context to Context pub trait IntoContext { fn into_context(self) -> Context; } // Implement the custom trait for converting Context to Context impl IntoContext for Context where T: Send + Sync + 'static + Into, U: Send + Sync + 'static, { fn into_context(self) -> Context { self.map(|current| current.into()) } } impl AsyncEngineContextProvider for Context { fn context(&self) -> Arc { self.controller.clone() } } #[derive(Debug, Clone)] pub struct StreamContext { controller: Arc, registry: Arc, stages: Vec, } impl StreamContext { fn new(controller: Arc, registry: Registry) -> Self { StreamContext { controller, registry: Arc::new(registry), stages: Vec::new(), } } /// Retrieve an object from the registry by key and type. pub fn get(&self, key: &str) -> Result, String> { self.registry.get_shared(key) } /// Clone a unique object from the registry by key and type. pub fn clone_unique(&self, key: &str) -> Result { self.registry.clone_unique(key) } pub fn registry(&self) -> Arc { self.registry.clone() } pub fn stages(&self) -> &Vec { &self.stages } pub fn add_stage(&mut self, stage: &str) { self.stages.push(stage.to_string()); } } #[async_trait] impl AsyncEngineContext for StreamContext { fn id(&self) -> &str { self.controller.id() } fn stop(&self) { self.controller.stop(); } fn kill(&self) { self.controller.kill(); } fn stop_generating(&self) { self.controller.stop_generating(); } fn is_stopped(&self) -> bool { self.controller.is_stopped() } fn is_killed(&self) -> bool { self.controller.is_killed() } async fn stopped(&self) { self.controller.stopped().await } async fn killed(&self) { self.controller.killed().await } } impl AsyncEngineContextProvider for StreamContext { fn context(&self) -> Arc { self.controller.clone() } } impl From> for StreamContext { fn from(value: Context) -> Self { StreamContext::new(value.controller, value.registry) } } // TODO - refactor here - this came from the triton-llm-async-engine crate use tokio::sync::watch::{channel, Receiver, Sender}; #[derive(Debug, Eq, PartialEq)] enum State { Live, Stopped, Killed, } /// A context implementation with cancellation propagation. #[derive(Debug)] pub struct Controller { id: String, tx: Sender, rx: Receiver, } impl Controller { pub fn new(id: String) -> Self { let (tx, rx) = channel(State::Live); Self { id, tx, rx } } pub fn id(&self) -> &str { &self.id } } impl Default for Controller { fn default() -> Self { Self::new(uuid::Uuid::new_v4().to_string()) } } impl AsyncEngineController for Controller {} #[async_trait] impl AsyncEngineContext for Controller { fn id(&self) -> &str { &self.id } fn is_stopped(&self) -> bool { *self.rx.borrow() != State::Live } fn is_killed(&self) -> bool { *self.rx.borrow() == State::Killed } async fn stopped(&self) { let mut rx = self.rx.clone(); if *rx.borrow_and_update() != State::Live { return; } let _ = rx.changed().await; } async fn killed(&self) { let mut rx = self.rx.clone(); if *rx.borrow_and_update() == State::Killed { return; } let _ = rx.changed().await; } fn stop_generating(&self) { self.stop(); } fn stop(&self) { let _ = self.tx.send(State::Stopped); } fn kill(&self) { let _ = self.tx.send(State::Killed); } } #[cfg(test)] mod tests { use super::*; #[derive(Debug, Clone)] struct Input { value: String, } #[derive(Debug, Clone)] struct Processed { length: usize, } #[derive(Debug, Clone)] struct Final { message: String, } impl From for Processed { fn from(input: Input) -> Self { Processed { length: input.value.len(), } } } impl From for Final { fn from(processed: Processed) -> Self { Final { message: format!("Processed length: {}", processed.length), } } } #[test] fn test_insert_and_get() { let mut ctx = Context::new(Input { value: "Hello".to_string(), }); ctx.insert("key1", 42); ctx.insert("key2", "some data".to_string()); assert_eq!(*ctx.get::("key1").unwrap(), 42); assert_eq!(*ctx.get::("key2").unwrap(), "some data"); assert!(ctx.get::("key1").is_err()); // Testing a downcast failure } #[test] fn test_transfer() { let ctx = Context::new(Input { value: "Hello".to_string(), }); let (input, ctx) = ctx.transfer(Processed { length: 5 }); assert_eq!(input.value, "Hello"); assert_eq!(ctx.length, 5); } #[test] fn test_map() { let ctx = Context::new(Input { value: "Hello".to_string(), }); let ctx: Context = ctx.map(|input| input.into()); let ctx: Context = ctx.map(|processed| processed.into()); assert_eq!(ctx.current.message, "Processed length: 5"); } #[test] fn test_into_context() { let ctx = Context::new(Input { value: "Hello".to_string(), }); let ctx: Context = ctx.into_context(); let ctx: Context = ctx.into_context(); assert_eq!(ctx.current.message, "Processed length: 5"); } }