Unverified Commit ddcba74b authored by Simo Lin, committed by GitHub

[router] Worker Management Workflow Engine (#11868)

parent 0917c5da
......@@ -4,17 +4,24 @@
//! them asynchronously in background worker tasks.
use std::{
collections::HashMap,
sync::{Arc, Weak},
time::{Duration, SystemTime},
};
use dashmap::DashMap;
use metrics::{counter, gauge, histogram};
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::{
core::WorkerManager,
config::{RouterConfig, RoutingMode},
core::{
workflow::{
WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus,
},
WorkerManager,
},
metrics::RouterMetrics,
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
server::AppContext,
};
......@@ -24,6 +31,7 @@ use crate::{
pub enum Job {
AddWorker { config: Box<WorkerConfigRequest> },
RemoveWorker { url: String },
InitializeWorkersFromConfig { router_config: Box<RouterConfig> },
}
impl Job {
......@@ -32,6 +40,7 @@ impl Job {
match self {
Job::AddWorker { .. } => "AddWorker",
Job::RemoveWorker { .. } => "RemoveWorker",
Job::InitializeWorkersFromConfig { .. } => "InitializeWorkersFromConfig",
}
}
......@@ -40,6 +49,7 @@ impl Job {
match self {
Job::AddWorker { config } => &config.url,
Job::RemoveWorker { url } => url,
Job::InitializeWorkersFromConfig { .. } => "startup",
}
}
}
......@@ -98,7 +108,7 @@ impl Default for JobQueueConfig {
fn default() -> Self {
Self {
queue_capacity: 1000,
worker_count: 2,
worker_count: 10,
}
}
}
......@@ -166,7 +176,7 @@ impl JobQueue {
pub async fn submit(&self, job: Job) -> Result<(), String> {
// Check if context is still alive before accepting jobs
if self.context.upgrade().is_none() {
counter!("sgl_router_job_shutdown_rejected_total").increment(1);
RouterMetrics::record_job_shutdown_rejected();
return Err("Job queue shutting down: AppContext dropped".to_string());
}
......@@ -183,8 +193,7 @@ impl JobQueue {
match self.tx.send(job).await {
Ok(_) => {
let queue_depth = self.tx.max_capacity() - self.tx.capacity();
gauge!("sgl_router_job_queue_depth").set(queue_depth as f64);
RouterMetrics::set_job_queue_depth(queue_depth);
info!(
"Job submitted: type={}, worker={}, queue_depth={}",
job_type, worker_url, queue_depth
......@@ -192,8 +201,7 @@ impl JobQueue {
Ok(())
}
Err(_) => {
counter!("sgl_router_job_queue_full_total").increment(1);
// Remove status on failure
RouterMetrics::record_job_queue_full();
self.status_map.remove(&worker_url);
Err("Worker job queue full".to_string())
}
......@@ -246,39 +254,16 @@ impl JobQueue {
// Upgrade weak reference to process job
match context.upgrade() {
Some(ctx) => {
// Execute job
let result = Self::execute_job(&job, &ctx).await;
let duration = start.elapsed();
// Record metrics
histogram!("sgl_router_job_duration_seconds", "job_type" => job_type.clone())
.record(duration.as_secs_f64());
match result {
Ok(message) => {
counter!("sgl_router_job_success_total", "job_type" => job_type.clone())
.increment(1);
// Remove status on success - worker in registry is sufficient
status_map.remove(&worker_url);
info!(
"Worker {} completed job: type={}, worker={}, duration={:.3}s, result={}",
worker_id, job_type, worker_url, duration.as_secs_f64(), message
);
}
Err(error) => {
counter!("sgl_router_job_failure_total", "job_type" => job_type.clone())
.increment(1);
// Keep failed status for API to report error details
status_map.insert(
worker_url.clone(),
JobStatus::failed(&job_type, &worker_url, error.clone()),
);
warn!(
"Worker {} failed job: type={}, worker={}, duration={:.3}s, error={}",
worker_id, job_type, worker_url, duration.as_secs_f64(), error
);
}
}
Self::record_job_completion(
&job_type,
&worker_url,
worker_id,
duration,
&result,
&status_map,
);
}
None => {
let error_msg = "AppContext dropped".to_string();
......@@ -311,12 +296,28 @@ impl JobQueue {
async fn execute_job(job: &Job, context: &Arc<AppContext>) -> Result<String, String> {
match job {
Job::AddWorker { config } => {
// Register worker with is_healthy=false
let worker =
WorkerManager::add_worker_from_config(config.as_ref(), context).await?;
let engine = context
.workflow_engine
.get()
.ok_or_else(|| "Workflow engine not initialized".to_string())?;
let instance_id = Self::start_worker_workflow(engine, config, context).await?;
info!(
"Started worker registration workflow for {} (instance: {})",
config.url, instance_id
);
let timeout_duration =
Duration::from_secs(context.router_config.worker_startup_timeout_secs + 30);
// Validate and activate
WorkerManager::validate_and_activate_worker(&worker, context).await
Self::wait_for_workflow_completion(
engine,
instance_id,
&config.url,
timeout_duration,
)
.await
}
Job::RemoveWorker { url } => {
let result = WorkerManager::remove_worker(url, context);
......@@ -326,6 +327,204 @@ impl JobQueue {
}
result
}
Job::InitializeWorkersFromConfig { router_config } => {
let api_key = router_config.api_key.clone();
let mut worker_count = 0;
// Build (url, worker_type, bootstrap_port) tuples based on the routing mode
let workers: Vec<(String, &str, Option<u16>)> = match &router_config.mode {
RoutingMode::Regular { worker_urls } => worker_urls
.iter()
.map(|url| (url.clone(), "regular", None))
.collect(),
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
..
} => {
let prefill_workers = prefill_urls
.iter()
.map(|(url, port)| (url.clone(), "prefill", *port));
let decode_workers =
decode_urls.iter().map(|url| (url.clone(), "decode", None));
prefill_workers.chain(decode_workers).collect()
}
RoutingMode::OpenAI { .. } => {
info!("OpenAI mode: no workers to initialize");
return Ok("OpenAI mode: no workers to initialize".to_string());
}
};
info!(
"Creating AddWorker jobs for {} workers from config",
workers.len()
);
// Process all workers with a unified loop
for (url, worker_type, bootstrap_port) in workers {
let url_for_error = url.clone(); // Clone for error message
let config = WorkerConfigRequest {
url,
api_key: api_key.clone(),
worker_type: Some(worker_type.to_string()),
labels: HashMap::new(),
model_id: None,
priority: None,
cost: None,
tokenizer_path: None,
reasoning_parser: None,
tool_parser: None,
chat_template: None,
bootstrap_port,
health_check_timeout_secs: router_config.health_check.timeout_secs,
health_check_interval_secs: router_config.health_check.check_interval_secs,
health_success_threshold: router_config.health_check.success_threshold,
health_failure_threshold: router_config.health_check.failure_threshold,
max_connection_attempts: router_config.health_check.success_threshold * 10,
dp_aware: router_config.dp_aware,
};
let job = Job::AddWorker {
config: Box::new(config),
};
if let Some(queue) = context.worker_job_queue.get() {
queue.submit(job).await.map_err(|e| {
format!(
"Failed to submit AddWorker job for {} worker {}: {}",
worker_type, url_for_error, e
)
})?;
worker_count += 1;
} else {
return Err("JobQueue not available".to_string());
}
}
Ok(format!("Submitted {} AddWorker jobs", worker_count))
}
}
}
/// Start a workflow and return its instance ID
async fn start_worker_workflow(
engine: &Arc<WorkflowEngine>,
config: &WorkerConfigRequest,
context: &Arc<AppContext>,
) -> Result<WorkflowInstanceId, String> {
let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new());
workflow_context.set("worker_config", config.clone());
workflow_context.set_arc("app_context", Arc::clone(context));
engine
.start_workflow(WorkflowId::new("worker_registration"), workflow_context)
.await
.map_err(|e| format!("Failed to start worker registration workflow: {:?}", e))
}
/// Wait for workflow completion with adaptive polling
async fn wait_for_workflow_completion(
engine: &Arc<WorkflowEngine>,
instance_id: WorkflowInstanceId,
worker_url: &str,
timeout_duration: Duration,
) -> Result<String, String> {
let start = std::time::Instant::now();
let mut poll_interval = Duration::from_millis(100);
let max_poll_interval = Duration::from_millis(2000);
let poll_backoff = Duration::from_millis(200);
loop {
// Check timeout
if start.elapsed() > timeout_duration {
return Err(format!(
"Workflow timeout after {}s for worker {}",
timeout_duration.as_secs(),
worker_url
));
}
// Get workflow status
let state = engine
.get_status(instance_id)
.map_err(|e| format!("Failed to get workflow status: {:?}", e))?;
let result = match state.status {
WorkflowStatus::Completed => Ok(format!(
"Worker {} registered and activated successfully via workflow",
worker_url
)),
WorkflowStatus::Failed => {
let current_step = state.current_step.as_ref();
let step_name = current_step
.map(|s| s.to_string())
.unwrap_or_else(|| "unknown".to_string());
let error_msg = current_step
.and_then(|step_id| state.step_states.get(step_id))
.and_then(|s| s.last_error.as_deref())
.unwrap_or("Unknown error");
Err(format!(
"Workflow failed at step {}: {}",
step_name, error_msg
))
}
WorkflowStatus::Cancelled => {
Err(format!("Workflow cancelled for worker {}", worker_url))
}
WorkflowStatus::Pending | WorkflowStatus::Paused | WorkflowStatus::Running => {
tokio::time::sleep(poll_interval).await;
poll_interval = (poll_interval + poll_backoff).min(max_poll_interval);
continue;
}
};
// Clean up terminal workflow states
engine.state_store().cleanup_if_terminal(instance_id);
return result;
}
}
/// Record job completion metrics and update status
fn record_job_completion(
job_type: &str,
worker_url: &str,
worker_id: usize,
duration: Duration,
result: &Result<String, String>,
status_map: &Arc<DashMap<String, JobStatus>>,
) {
RouterMetrics::record_job_duration(job_type, duration);
match result {
Ok(message) => {
RouterMetrics::record_job_success(job_type);
status_map.remove(worker_url);
info!(
"Worker {} completed job: type={}, worker={}, duration={:.3}s, result={}",
worker_id,
job_type,
worker_url,
duration.as_secs_f64(),
message
);
}
Err(error) => {
RouterMetrics::record_job_failure(job_type);
status_map.insert(
worker_url.to_string(),
JobStatus::failed(job_type, worker_url, error.clone()),
);
warn!(
"Worker {} failed job: type={}, worker={}, duration={:.3}s, error={}",
worker_id,
job_type,
worker_url,
duration.as_secs_f64(),
error
);
}
}
}
......@@ -352,15 +551,3 @@ impl JobQueue {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_job_queue_config_default() {
let config = JobQueueConfig::default();
assert_eq!(config.queue_capacity, 1000);
assert_eq!(config.worker_count, 2);
}
}
......@@ -4,6 +4,7 @@
//! - Worker trait and implementations
//! - Error types
//! - Circuit breaker for reliability
//! - Workflow engine for multi-step operations
//! - Common utilities
pub mod circuit_breaker;
......@@ -15,6 +16,7 @@ pub mod worker;
pub mod worker_builder;
pub mod worker_manager;
pub mod worker_registry;
pub mod workflow;
pub use circuit_breaker::{
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
......@@ -23,8 +25,8 @@ pub use error::{WorkerError, WorkerResult};
pub use job_queue::{Job, JobQueue, JobQueueConfig};
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{
start_health_checker, worker_to_info, BasicWorker, ConnectionMode, DPAwareWorker,
HealthChecker, HealthConfig, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
worker_to_info, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
};
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager};
......
......@@ -8,7 +8,6 @@ use std::{
};
use async_trait::async_trait;
use futures;
use serde_json;
use tokio::{sync::RwLock, time};
......@@ -910,89 +909,6 @@ impl HealthChecker {
}
}
/// Start an async background health checker for a collection of workers
pub fn start_health_checker(
workers: Arc<std::sync::RwLock<Vec<Arc<dyn Worker>>>>,
check_interval_secs: u64,
) -> HealthChecker {
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
let handle = tokio::spawn(async move {
let mut interval = time::interval(Duration::from_secs(check_interval_secs));
// Counter for periodic load reset (every 10 health check cycles)
let mut check_count = 0u64;
const LOAD_RESET_INTERVAL: u64 = 10;
loop {
interval.tick().await;
// Check for shutdown signal
if shutdown_clone.load(Ordering::Acquire) {
tracing::debug!("Health checker shutting down");
break;
}
check_count += 1;
// Check health of all workers
let workers_to_check = match workers.read() {
Ok(guard) => guard.clone(),
Err(poisoned) => {
tracing::error!("Worker lock poisoned: {}", poisoned);
continue;
}
};
// Periodically reset load counters to prevent drift
// Only do this when we believe all workers should be idle
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
let max_load = workers_to_check.iter().map(|w| w.load()).max().unwrap_or(0);
// Only reset if load appears to be very low (likely drift)
if max_load <= 2 {
tracing::debug!(
"Resetting load counters to prevent drift (max_load: {})",
max_load
);
for worker in &workers_to_check {
worker.reset_load();
}
}
}
// Perform health checks concurrently
let health_checks = workers_to_check.iter().map(|worker| {
let worker_url = worker.url().to_string();
let was_healthy = worker.is_healthy();
async move {
match worker.check_health_async().await {
Ok(_) => {
if !was_healthy {
tracing::info!("Worker {} is now healthy", worker_url);
}
}
Err(e) => {
if was_healthy {
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
} else {
// Worker was already unhealthy, log at debug level
tracing::debug!("Worker {} remains unhealthy: {}", worker_url, e);
}
}
}
}
});
// Execute all health checks concurrently
futures::future::join_all(health_checks).await;
}
});
HealthChecker { handle, shutdown }
}
/// Helper to convert Worker trait object to WorkerInfo struct
pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
let worker_type_str = match worker.worker_type() {
......
//! Workflow definition types
use std::{sync::Arc, time::Duration};
use super::{
executor::StepExecutor,
types::{FailureAction, RetryPolicy, StepId, WorkflowId},
};
/// Definition of a single step within a workflow
pub struct StepDefinition {
pub id: StepId,
pub name: String,
pub executor: Arc<dyn StepExecutor>,
pub retry_policy: Option<RetryPolicy>,
pub timeout: Option<Duration>,
pub on_failure: FailureAction,
}
impl StepDefinition {
pub fn new(
id: impl Into<String>,
name: impl Into<String>,
executor: Arc<dyn StepExecutor>,
) -> Self {
Self {
id: StepId::new(id.into()),
name: name.into(),
executor,
retry_policy: None,
timeout: None,
on_failure: FailureAction::FailWorkflow,
}
}
pub fn with_retry(mut self, policy: RetryPolicy) -> Self {
self.retry_policy = Some(policy);
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn with_failure_action(mut self, action: FailureAction) -> Self {
self.on_failure = action;
self
}
}
/// Complete workflow definition
pub struct WorkflowDefinition {
pub id: WorkflowId,
pub name: String,
pub steps: Vec<StepDefinition>,
pub default_retry_policy: RetryPolicy,
pub default_timeout: Duration,
}
impl WorkflowDefinition {
pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
Self {
id: WorkflowId::new(id.into()),
name: name.into(),
steps: Vec::new(),
default_retry_policy: RetryPolicy::default(),
default_timeout: Duration::from_secs(300), // 5 minutes
}
}
pub fn add_step(mut self, step: StepDefinition) -> Self {
self.steps.push(step);
self
}
pub fn with_default_retry(mut self, policy: RetryPolicy) -> Self {
self.default_retry_policy = policy;
self
}
pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
self.default_timeout = timeout;
self
}
/// Get the retry policy for a step (step-specific or default)
pub fn get_retry_policy<'a>(&'a self, step: &'a StepDefinition) -> &'a RetryPolicy {
step.retry_policy
.as_ref()
.unwrap_or(&self.default_retry_policy)
}
/// Get the timeout for a step (step-specific or default)
pub fn get_timeout(&self, step: &StepDefinition) -> Duration {
step.timeout.unwrap_or(self.default_timeout)
}
}
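// Illustrative sketch, not part of this commit: how the builder API above is
// intended to compose a definition with per-step overrides. `NoopStep` is a
// throwaway executor used only for this example.
#[cfg(test)]
mod builder_sketch {
    use std::{sync::Arc, time::Duration};

    use async_trait::async_trait;

    use super::{StepDefinition, WorkflowDefinition};
    use crate::core::workflow::types::{
        BackoffStrategy, FailureAction, RetryPolicy, StepResult, WorkflowContext, WorkflowResult,
    };
    use crate::core::workflow::StepExecutor;

    struct NoopStep;

    #[async_trait]
    impl StepExecutor for NoopStep {
        async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
            Ok(StepResult::Success)
        }
    }

    #[test]
    fn compose_definition_with_overrides() {
        let workflow = WorkflowDefinition::new("example", "Example workflow")
            .with_default_timeout(Duration::from_secs(60))
            .add_step(
                StepDefinition::new("probe", "Probe upstream", Arc::new(NoopStep))
                    .with_timeout(Duration::from_secs(5))
                    .with_retry(RetryPolicy {
                        max_attempts: 5,
                        backoff: BackoffStrategy::Linear {
                            increment: Duration::from_millis(500),
                            max: Duration::from_secs(5),
                        },
                    })
                    .with_failure_action(FailureAction::ContinueNextStep),
            );

        // Step-specific settings win over workflow defaults.
        let step = &workflow.steps[0];
        assert_eq!(workflow.get_timeout(step), Duration::from_secs(5));
        assert_eq!(workflow.get_retry_policy(step).max_attempts, 5);
    }
}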
//! Workflow execution engine
use std::{collections::HashMap, sync::Arc, time::Duration};
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use chrono::Utc;
use parking_lot::RwLock;
use tokio::time::timeout;
use super::{
definition::{StepDefinition, WorkflowDefinition},
event::{EventBus, WorkflowEvent},
state::WorkflowStateStore,
types::*,
};
/// Linear backoff implementation that increases delay by a fixed amount each retry
struct LinearBackoff {
current: Duration,
increment: Duration,
max: Duration,
}
impl LinearBackoff {
fn new(increment: Duration, max: Duration) -> Self {
Self {
current: increment,
increment,
max,
}
}
}
impl Backoff for LinearBackoff {
fn next_backoff(&mut self) -> Option<Duration> {
let next = self.current;
self.current = (self.current + self.increment).min(self.max);
Some(next)
}
fn reset(&mut self) {
self.current = self.increment;
}
}
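// Illustrative sketch, not part of this commit: the linear schedule yields
// increment, 2*increment, 3*increment, ... and then saturates at `max`.
#[cfg(test)]
mod linear_backoff_sketch {
    use std::time::Duration;

    use backoff::backoff::Backoff;

    use super::LinearBackoff;

    #[test]
    fn delays_grow_then_saturate() {
        let mut schedule =
            LinearBackoff::new(Duration::from_millis(100), Duration::from_millis(250));
        assert_eq!(schedule.next_backoff(), Some(Duration::from_millis(100)));
        assert_eq!(schedule.next_backoff(), Some(Duration::from_millis(200)));
        // Capped at `max` from here on.
        assert_eq!(schedule.next_backoff(), Some(Duration::from_millis(250)));
        assert_eq!(schedule.next_backoff(), Some(Duration::from_millis(250)));

        schedule.reset();
        assert_eq!(schedule.next_backoff(), Some(Duration::from_millis(100)));
    }
}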
/// Main workflow execution engine
pub struct WorkflowEngine {
definitions: Arc<RwLock<HashMap<WorkflowId, Arc<WorkflowDefinition>>>>,
state_store: WorkflowStateStore,
event_bus: Arc<EventBus>,
}
impl WorkflowEngine {
pub fn new() -> Self {
Self {
definitions: Arc::new(RwLock::new(HashMap::new())),
state_store: WorkflowStateStore::new(),
event_bus: Arc::new(EventBus::new()),
}
}
/// Start a background task to periodically clean up old workflow states
///
/// This prevents unbounded memory growth by removing completed/failed workflows
/// that are older than the specified TTL.
///
/// # Arguments
///
/// * `ttl` - Time-to-live for terminal workflows (default: 1 hour)
/// * `interval` - How often to run cleanup (default: 5 minutes)
///
/// # Returns
///
/// A join handle for the cleanup task that can be used to stop it.
pub fn start_cleanup_task(
&self,
ttl: Option<Duration>,
interval: Option<Duration>,
) -> tokio::task::JoinHandle<()> {
let state_store = self.state_store.clone();
let ttl = ttl.unwrap_or(Duration::from_secs(3600)); // 1 hour default
let interval = interval.unwrap_or(Duration::from_secs(300)); // 5 minutes default
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
ticker.tick().await;
state_store.cleanup_old_workflows(ttl);
}
})
}
/// Register a workflow definition
pub fn register_workflow(&self, definition: WorkflowDefinition) {
let id = definition.id.clone();
self.definitions.write().insert(id, Arc::new(definition));
}
/// Get the event bus for subscribing to workflow events
pub fn event_bus(&self) -> Arc<EventBus> {
Arc::clone(&self.event_bus)
}
/// Get the state store
pub fn state_store(&self) -> &WorkflowStateStore {
&self.state_store
}
/// Start a new workflow instance
pub async fn start_workflow(
&self,
definition_id: WorkflowId,
context: WorkflowContext,
) -> WorkflowResult<WorkflowInstanceId> {
// Get workflow definition
let definition = {
let definitions = self.definitions.read();
definitions
.get(&definition_id)
.cloned()
.ok_or_else(|| WorkflowError::DefinitionNotFound(definition_id.clone()))?
};
// Create new workflow instance
let instance_id = context.instance_id;
let mut state = WorkflowState::new(instance_id, definition_id.clone());
state.status = WorkflowStatus::Running;
state.context = context;
// Initialize step states
for step in &definition.steps {
state
.step_states
.insert(step.id.clone(), StepState::default());
}
// Save initial state
self.state_store.save(state)?;
// Emit workflow started event
self.event_bus
.publish(WorkflowEvent::WorkflowStarted {
instance_id,
definition_id,
})
.await;
// Execute workflow in background
let engine = self.clone_for_execution();
let def = Arc::clone(&definition);
tokio::spawn(async move {
if let Err(e) = engine.execute_workflow(instance_id, def).await {
tracing::error!(instance_id = %instance_id, error = ?e, "Workflow execution failed");
}
});
Ok(instance_id)
}
/// Execute a workflow (internal)
async fn execute_workflow(
&self,
instance_id: WorkflowInstanceId,
definition: Arc<WorkflowDefinition>,
) -> WorkflowResult<()> {
let start_time = std::time::Instant::now();
for step in &definition.steps {
// Check if workflow was cancelled
let state = self.state_store.load(instance_id)?;
if state.status == WorkflowStatus::Cancelled {
self.event_bus
.publish(WorkflowEvent::WorkflowCancelled { instance_id })
.await;
return Ok(());
}
// Execute step with retry
match self
.execute_step_with_retry(instance_id, step, &definition)
.await
{
Ok(StepResult::Success) => {
// Continue to next step
}
Ok(StepResult::Skip) => {
// Step was skipped, continue to next
continue;
}
Ok(StepResult::Failure) | Err(_) => {
// Handle failure based on failure action
match step.on_failure {
FailureAction::FailWorkflow => {
let error_msg = format!("Step {} failed", step.id);
self.state_store.update(instance_id, |s| {
s.status = WorkflowStatus::Failed;
})?;
self.event_bus
.publish(WorkflowEvent::WorkflowFailed {
instance_id,
failed_step: step.id.clone(),
error: error_msg,
})
.await;
return Ok(());
}
FailureAction::ContinueNextStep => {
// Mark step as skipped and continue
self.state_store.update(instance_id, |s| {
if let Some(step_state) = s.step_states.get_mut(&step.id) {
step_state.status = StepStatus::Skipped;
}
})?;
continue;
}
FailureAction::RetryIndefinitely => {
// This should not happen as execute_step_with_retry handles it
unreachable!("RetryIndefinitely should be handled in retry logic");
}
}
}
}
}
// Workflow completed successfully
self.state_store.update(instance_id, |s| {
s.status = WorkflowStatus::Completed;
})?;
let duration = start_time.elapsed();
self.event_bus
.publish(WorkflowEvent::WorkflowCompleted {
instance_id,
duration,
})
.await;
Ok(())
}
/// Execute a step with retry logic
async fn execute_step_with_retry(
&self,
instance_id: WorkflowInstanceId,
step: &StepDefinition,
definition: &WorkflowDefinition,
) -> WorkflowResult<StepResult> {
let retry_policy = definition.get_retry_policy(step);
let step_timeout = definition.get_timeout(step);
let mut attempt = 1;
let max_attempts = if matches!(step.on_failure, FailureAction::RetryIndefinitely) {
u32::MAX
} else {
retry_policy.max_attempts
};
let mut backoff = Self::create_backoff(&retry_policy.backoff);
loop {
// Check for cancellation before starting/retrying step
{
let state = self.state_store.load(instance_id)?;
if state.status == WorkflowStatus::Cancelled {
return Err(WorkflowError::Cancelled(instance_id));
}
}
// Update step state
self.state_store.update(instance_id, |s| {
s.current_step = Some(step.id.clone());
if let Some(step_state) = s.step_states.get_mut(&step.id) {
step_state.status = if attempt == 1 {
StepStatus::Running
} else {
StepStatus::Retrying
};
step_state.attempt = attempt;
step_state.started_at = Some(Utc::now());
}
})?;
// Emit step started event
self.event_bus
.publish(WorkflowEvent::StepStarted {
instance_id,
step_id: step.id.clone(),
attempt,
})
.await;
// Get current context
let mut context = self.state_store.load(instance_id)?.context;
// Execute step with timeout
let step_start = std::time::Instant::now();
let result = timeout(step_timeout, step.executor.execute(&mut context)).await;
let step_duration = step_start.elapsed();
// Save updated context
self.state_store.update(instance_id, |s| {
s.context = context.clone();
})?;
match result {
Ok(Ok(StepResult::Success)) => {
// Step succeeded
self.state_store.update(instance_id, |s| {
if let Some(step_state) = s.step_states.get_mut(&step.id) {
step_state.status = StepStatus::Succeeded;
step_state.completed_at = Some(Utc::now());
}
})?;
self.event_bus
.publish(WorkflowEvent::StepSucceeded {
instance_id,
step_id: step.id.clone(),
duration: step_duration,
})
.await;
// Call on_success hook
if let Err(e) = step.executor.on_success(&context).await {
tracing::warn!(step_id = %step.id, error = ?e, "on_success hook failed");
}
return Ok(StepResult::Success);
}
Ok(Ok(StepResult::Skip)) => {
return Ok(StepResult::Skip);
}
Ok(Ok(StepResult::Failure)) | Ok(Err(_)) | Err(_) => {
let (error_msg, should_retry) = match result {
Ok(Err(e)) => {
let msg = format!("{}", e);
let retryable = step.executor.is_retryable(&e);
(msg, retryable)
}
Err(_) => (
format!("Step timeout after {:?}", step_timeout),
true, // Timeouts are retryable
),
_ => ("Step failed".to_string(), false),
};
let will_retry = should_retry && attempt < max_attempts;
// Update step state
self.state_store.update(instance_id, |s| {
if let Some(step_state) = s.step_states.get_mut(&step.id) {
step_state.status = if will_retry {
StepStatus::Retrying
} else {
StepStatus::Failed
};
step_state.last_error = Some(error_msg.clone());
if !will_retry {
step_state.completed_at = Some(Utc::now());
}
}
})?;
// Emit step failed event
self.event_bus
.publish(WorkflowEvent::StepFailed {
instance_id,
step_id: step.id.clone(),
error: error_msg.clone(),
will_retry,
})
.await;
if will_retry {
// Calculate backoff delay
let delay = backoff
.next_backoff()
.unwrap_or_else(|| Duration::from_secs(1));
self.event_bus
.publish(WorkflowEvent::StepRetrying {
instance_id,
step_id: step.id.clone(),
attempt: attempt + 1,
delay,
})
.await;
tokio::time::sleep(delay).await;
attempt += 1;
} else {
// No more retries, call on_failure hook
// Create a generic error for the hook
let hook_error = WorkflowError::StepFailed {
step_id: step.id.clone(),
message: error_msg,
};
if let Err(hook_err) = step.executor.on_failure(&context, &hook_error).await
{
tracing::warn!(step_id = %step.id, error = ?hook_err, "on_failure hook failed");
}
return Ok(StepResult::Failure);
}
}
}
}
}
/// Create a backoff instance from strategy
fn create_backoff(strategy: &BackoffStrategy) -> Box<dyn Backoff + Send> {
match strategy {
BackoffStrategy::Fixed(duration) => {
// For fixed backoff, use exponential with multiplier 1.0
let backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(*duration)
.with_multiplier(1.0)
.with_max_interval(*duration)
.with_max_elapsed_time(None)
.build();
Box::new(backoff)
}
BackoffStrategy::Exponential { base, max } => {
let backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(*base)
.with_max_interval(*max)
.with_max_elapsed_time(None)
.build();
Box::new(backoff)
}
BackoffStrategy::Linear { increment, max } => {
// Use proper linear backoff: increment, 2*increment, 3*increment, ...
Box::new(LinearBackoff::new(*increment, *max))
}
}
}
/// Cancel a running workflow
pub async fn cancel_workflow(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<()> {
self.state_store.update(instance_id, |s| {
s.status = WorkflowStatus::Cancelled;
})?;
self.event_bus
.publish(WorkflowEvent::WorkflowCancelled { instance_id })
.await;
Ok(())
}
/// Get workflow status
pub fn get_status(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<WorkflowState> {
self.state_store.load(instance_id)
}
/// Clone engine for async execution
fn clone_for_execution(&self) -> Self {
Self {
definitions: Arc::clone(&self.definitions),
state_store: self.state_store.clone(),
event_bus: Arc::clone(&self.event_bus),
}
}
}
impl Default for WorkflowEngine {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for WorkflowEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WorkflowEngine")
.field("definitions_count", &self.definitions.read().len())
.field("state_count", &self.state_store.count())
.finish()
}
}
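// Illustrative sketch, not part of this commit: driving the engine end to end.
// Register a single-step workflow, start it, and poll until it reaches a
// terminal status. `NoopStep` is an assumed throwaway executor used only here.
#[cfg(test)]
mod engine_usage_sketch {
    use std::{sync::Arc, time::Duration};

    use async_trait::async_trait;

    use super::WorkflowEngine;
    use crate::core::workflow::{
        StepDefinition, StepExecutor, StepResult, WorkflowContext, WorkflowDefinition,
        WorkflowId, WorkflowInstanceId, WorkflowResult, WorkflowStatus,
    };

    struct NoopStep;

    #[async_trait]
    impl StepExecutor for NoopStep {
        async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
            Ok(StepResult::Success)
        }
    }

    #[tokio::test]
    async fn single_step_workflow_completes() {
        let engine = WorkflowEngine::new();
        engine.register_workflow(
            WorkflowDefinition::new("noop", "No-op workflow")
                .add_step(StepDefinition::new("only", "Only step", Arc::new(NoopStep))),
        );

        let instance_id = engine
            .start_workflow(
                WorkflowId::new("noop"),
                WorkflowContext::new(WorkflowInstanceId::new()),
            )
            .await
            .expect("workflow should start");

        // Execution happens in a background task, so poll briefly for completion.
        for _ in 0..50 {
            if engine.get_status(instance_id).unwrap().status == WorkflowStatus::Completed {
                return;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        panic!("workflow did not complete in time");
    }
}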
//! Workflow event system for observability and monitoring
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use tokio::sync::RwLock;
use tracing::{error, info, warn};
use super::types::{StepId, WorkflowId, WorkflowInstanceId};
/// Events emitted by the workflow engine
#[derive(Debug, Clone)]
pub enum WorkflowEvent {
WorkflowStarted {
instance_id: WorkflowInstanceId,
definition_id: WorkflowId,
},
StepStarted {
instance_id: WorkflowInstanceId,
step_id: StepId,
attempt: u32,
},
StepSucceeded {
instance_id: WorkflowInstanceId,
step_id: StepId,
duration: Duration,
},
StepFailed {
instance_id: WorkflowInstanceId,
step_id: StepId,
error: String,
will_retry: bool,
},
StepRetrying {
instance_id: WorkflowInstanceId,
step_id: StepId,
attempt: u32,
delay: Duration,
},
WorkflowCompleted {
instance_id: WorkflowInstanceId,
duration: Duration,
},
WorkflowFailed {
instance_id: WorkflowInstanceId,
failed_step: StepId,
error: String,
},
WorkflowCancelled {
instance_id: WorkflowInstanceId,
},
}
/// Trait for subscribing to workflow events
#[async_trait]
pub trait EventSubscriber: Send + Sync {
async fn on_event(&self, event: &WorkflowEvent);
}
/// Event bus for publishing and subscribing to workflow events
pub struct EventBus {
subscribers: Arc<RwLock<Vec<Arc<dyn EventSubscriber>>>>,
}
impl EventBus {
pub fn new() -> Self {
Self {
subscribers: Arc::new(RwLock::new(Vec::new())),
}
}
/// Subscribe to workflow events
pub async fn subscribe(&self, subscriber: Arc<dyn EventSubscriber>) {
self.subscribers.write().await.push(subscriber);
}
/// Publish an event to all subscribers
pub async fn publish(&self, event: WorkflowEvent) {
let subscribers = self.subscribers.read().await;
for subscriber in subscribers.iter() {
subscriber.on_event(&event).await;
}
}
}
impl Default for EventBus {
fn default() -> Self {
Self::new()
}
}
/// Logging subscriber that logs events using tracing
pub struct LoggingSubscriber;
#[async_trait]
impl EventSubscriber for LoggingSubscriber {
async fn on_event(&self, event: &WorkflowEvent) {
match event {
WorkflowEvent::WorkflowStarted {
instance_id,
definition_id,
} => {
info!(
instance_id = %instance_id,
definition_id = %definition_id,
"Workflow started"
);
}
WorkflowEvent::StepStarted {
instance_id,
step_id,
attempt,
} => {
info!(
instance_id = %instance_id,
step_id = %step_id,
attempt = attempt,
"Step started"
);
}
WorkflowEvent::StepSucceeded {
instance_id,
step_id,
duration,
} => {
info!(
instance_id = %instance_id,
step_id = %step_id,
duration_ms = duration.as_millis(),
"Step succeeded"
);
}
WorkflowEvent::StepFailed {
instance_id,
step_id,
error,
will_retry,
} => {
warn!(
instance_id = %instance_id,
step_id = %step_id,
error = error,
will_retry = will_retry,
"Step failed"
);
}
WorkflowEvent::StepRetrying {
instance_id,
step_id,
attempt,
delay,
} => {
info!(
instance_id = %instance_id,
step_id = %step_id,
attempt = attempt,
delay_ms = delay.as_millis(),
"Step retrying"
);
}
WorkflowEvent::WorkflowCompleted {
instance_id,
duration,
} => {
info!(
instance_id = %instance_id,
duration_ms = duration.as_millis(),
"Workflow completed"
);
}
WorkflowEvent::WorkflowFailed {
instance_id,
failed_step,
error,
} => {
error!(
instance_id = %instance_id,
failed_step = %failed_step,
error = error,
"Workflow failed"
);
}
WorkflowEvent::WorkflowCancelled { instance_id } => {
info!(instance_id = %instance_id, "Workflow cancelled");
}
}
}
}
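// Illustrative sketch, not part of this commit: a custom subscriber that only
// counts step failures, showing how EventBus fan-out could feed metrics.
// `FailureCounter` is an assumed example type.
#[cfg(test)]
mod subscriber_sketch {
    use std::sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    };

    use async_trait::async_trait;

    use super::{EventBus, EventSubscriber, WorkflowEvent};
    use crate::core::workflow::types::{StepId, WorkflowInstanceId};

    #[derive(Default)]
    struct FailureCounter {
        failures: AtomicU64,
    }

    #[async_trait]
    impl EventSubscriber for FailureCounter {
        async fn on_event(&self, event: &WorkflowEvent) {
            if matches!(event, WorkflowEvent::StepFailed { .. }) {
                self.failures.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    #[tokio::test]
    async fn counts_step_failures() {
        let bus = EventBus::new();
        let counter = Arc::new(FailureCounter::default());
        bus.subscribe(counter.clone()).await;

        bus.publish(WorkflowEvent::StepFailed {
            instance_id: WorkflowInstanceId::new(),
            step_id: StepId::new("probe"),
            error: "boom".to_string(),
            will_retry: false,
        })
        .await;

        assert_eq!(counter.failures.load(Ordering::Relaxed), 1);
    }
}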
//! Step executor trait and implementations
use async_trait::async_trait;
use super::types::{StepResult, WorkflowContext, WorkflowError, WorkflowResult};
/// Trait for executing individual workflow steps
#[async_trait]
pub trait StepExecutor: Send + Sync {
/// Execute the step with the given context
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult>;
/// Check if an error is retryable
///
/// Override this method to customize which errors should trigger retries.
/// By default, all errors are considered retryable.
fn is_retryable(&self, _error: &WorkflowError) -> bool {
true
}
/// Called when the step succeeds
///
/// This hook allows steps to perform cleanup or additional actions
/// after successful execution.
async fn on_success(&self, _context: &WorkflowContext) -> WorkflowResult<()> {
Ok(())
}
/// Called when the step fails after all retries
///
/// This hook allows steps to perform cleanup or compensation logic
/// when the step cannot complete successfully.
async fn on_failure(
&self,
_context: &WorkflowContext,
_error: &WorkflowError,
) -> WorkflowResult<()> {
Ok(())
}
}
/// Simple function-based step executor
pub struct FunctionStep<F>
where
F: Fn(
&mut WorkflowContext,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = WorkflowResult<StepResult>> + Send + '_>,
> + Send
+ Sync,
{
func: F,
}
impl<F> FunctionStep<F>
where
F: Fn(
&mut WorkflowContext,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = WorkflowResult<StepResult>> + Send + '_>,
> + Send
+ Sync,
{
pub fn new(func: F) -> Self {
Self { func }
}
}
#[async_trait]
impl<F> StepExecutor for FunctionStep<F>
where
F: Fn(
&mut WorkflowContext,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = WorkflowResult<StepResult>> + Send + '_>,
> + Send
+ Sync,
{
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
(self.func)(context).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::workflow::types::WorkflowInstanceId;
struct TestStep {
should_succeed: bool,
}
#[async_trait]
impl StepExecutor for TestStep {
async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
if self.should_succeed {
Ok(StepResult::Success)
} else {
Err(WorkflowError::StepFailed {
step_id: crate::core::workflow::types::StepId::new("test"),
message: "test error".to_string(),
})
}
}
}
#[tokio::test]
async fn test_step_executor_success() {
let step = TestStep {
should_succeed: true,
};
let mut context = WorkflowContext::new(WorkflowInstanceId::new());
let result = step.execute(&mut context).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), StepResult::Success);
}
#[tokio::test]
async fn test_step_executor_failure() {
let step = TestStep {
should_succeed: false,
};
let mut context = WorkflowContext::new(WorkflowInstanceId::new());
let result = step.execute(&mut context).await;
assert!(result.is_err());
}
}
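// Illustrative sketch, not part of this commit: overriding `is_retryable` so
// the engine only retries errors the step marks as transient. The "transient:"
// message prefix and `FlakyProbeStep` are purely example conventions.
#[cfg(test)]
mod retry_hook_sketch {
    use super::*;
    use crate::core::workflow::types::StepId;

    struct FlakyProbeStep;

    #[async_trait]
    impl StepExecutor for FlakyProbeStep {
        async fn execute(&self, _context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
            Err(WorkflowError::StepFailed {
                step_id: StepId::new("flaky_probe"),
                message: "transient: upstream not ready".to_string(),
            })
        }

        fn is_retryable(&self, error: &WorkflowError) -> bool {
            // Retry only failures explicitly tagged as transient.
            matches!(error, WorkflowError::StepFailed { message, .. } if message.starts_with("transient:"))
        }
    }

    #[test]
    fn only_transient_failures_are_retryable() {
        let step = FlakyProbeStep;
        let transient = WorkflowError::StepFailed {
            step_id: StepId::new("flaky_probe"),
            message: "transient: upstream not ready".to_string(),
        };
        let permanent = WorkflowError::StepFailed {
            step_id: StepId::new("flaky_probe"),
            message: "bad configuration".to_string(),
        };
        assert!(step.is_retryable(&transient));
        assert!(!step.is_retryable(&permanent));
    }
}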
//! Workflow engine for managing multi-step operations
mod definition;
mod engine;
mod event;
mod executor;
mod state;
pub mod steps;
pub mod types;
// Re-export main types
pub use definition::{StepDefinition, WorkflowDefinition};
pub use engine::WorkflowEngine;
pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent};
pub use executor::{FunctionStep, StepExecutor};
pub use state::WorkflowStateStore;
pub use steps::create_worker_registration_workflow;
pub use types::*;
//! Workflow state management
use std::{collections::HashMap, sync::Arc};
use parking_lot::RwLock;
use super::types::{
WorkflowError, WorkflowInstanceId, WorkflowResult, WorkflowState, WorkflowStatus,
};
/// In-memory state storage for workflow instances
#[derive(Clone)]
pub struct WorkflowStateStore {
states: Arc<RwLock<HashMap<WorkflowInstanceId, WorkflowState>>>,
}
impl WorkflowStateStore {
pub fn new() -> Self {
Self {
states: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Save workflow state
///
/// # Warning
///
/// This emits a warning if the workflow context contains unserializable data,
/// which would be lost if state persistence is later implemented.
pub fn save(&self, state: WorkflowState) -> WorkflowResult<()> {
if state.context.has_unserializable_data() {
tracing::warn!(
instance_id = %state.instance_id,
data_count = state.context.data_len(),
"Saving workflow state with {} unserializable context entries. \
This data cannot be persisted and will be lost on restart.",
state.context.data_len()
);
}
self.states.write().insert(state.instance_id, state);
Ok(())
}
/// Load workflow state by instance ID
pub fn load(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<WorkflowState> {
self.states
.read()
.get(&instance_id)
.cloned()
.ok_or(WorkflowError::NotFound(instance_id))
}
/// List all active workflows (Running or Pending)
pub fn list_active(&self) -> WorkflowResult<Vec<WorkflowState>> {
let states = self.states.read();
Ok(states
.values()
.filter(|s| matches!(s.status, WorkflowStatus::Running | WorkflowStatus::Pending))
.cloned()
.collect())
}
/// List all workflows
pub fn list_all(&self) -> WorkflowResult<Vec<WorkflowState>> {
let states = self.states.read();
Ok(states.values().cloned().collect())
}
/// Delete workflow state
pub fn delete(&self, instance_id: WorkflowInstanceId) -> WorkflowResult<()> {
self.states.write().remove(&instance_id);
Ok(())
}
/// Update workflow state using a closure
pub fn update<F>(&self, instance_id: WorkflowInstanceId, f: F) -> WorkflowResult<()>
where
F: FnOnce(&mut WorkflowState),
{
let mut states = self.states.write();
let state = states
.get_mut(&instance_id)
.ok_or(WorkflowError::NotFound(instance_id))?;
f(state);
state.updated_at = chrono::Utc::now();
Ok(())
}
/// Get count of workflows by status
pub fn count_by_status(&self, status: WorkflowStatus) -> usize {
self.states
.read()
.values()
.filter(|s| s.status == status)
.count()
}
/// Get total count of all workflows
pub fn count(&self) -> usize {
self.states.read().len()
}
/// Clean up old completed/failed/cancelled workflows beyond a time threshold
///
/// This prevents unbounded memory growth by removing workflow states that
/// have been in a terminal state (Completed, Failed, Cancelled) for longer
/// than the specified TTL (time-to-live).
///
/// Active workflows (Running, Pending, Paused) are never cleaned up.
///
/// # Arguments
///
/// * `ttl` - Time-to-live for terminal workflows. Workflows in terminal states
/// older than this will be removed.
///
/// # Returns
///
/// The number of workflow states removed.
pub fn cleanup_old_workflows(&self, ttl: std::time::Duration) -> usize {
let now = chrono::Utc::now();
let mut states = self.states.write();
let initial_count = states.len();
states.retain(|_, state| {
// Keep active workflows
if matches!(
state.status,
WorkflowStatus::Running | WorkflowStatus::Pending | WorkflowStatus::Paused
) {
return true;
}
// For terminal workflows, check age
let age = now
.signed_duration_since(state.updated_at)
.to_std()
.unwrap_or_default();
age < ttl
});
let removed_count = initial_count - states.len();
if removed_count > 0 {
tracing::info!(
removed = removed_count,
remaining = states.len(),
"Cleaned up old workflow states"
);
}
removed_count
}
/// Clean up a specific completed workflow immediately
///
/// This is useful for cleaning up workflows right after they complete
/// when you know they won't be queried again.
pub fn cleanup_if_terminal(&self, instance_id: WorkflowInstanceId) -> bool {
let mut states = self.states.write();
if let Some(state) = states.get(&instance_id) {
if matches!(
state.status,
WorkflowStatus::Completed | WorkflowStatus::Failed | WorkflowStatus::Cancelled
) {
states.remove(&instance_id);
return true;
}
}
false
}
}
impl Default for WorkflowStateStore {
fn default() -> Self {
Self::new()
}
}
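// Illustrative sketch, not part of this commit: the TTL-based cleanup keeps
// active workflows and drops terminal ones older than the TTL.
#[cfg(test)]
mod cleanup_sketch {
    use std::time::Duration;

    use super::*;
    use crate::core::workflow::types::WorkflowId;

    #[test]
    fn terminal_workflows_expire_but_active_ones_stay() {
        let store = WorkflowStateStore::new();

        let mut done = WorkflowState::new(WorkflowInstanceId::new(), WorkflowId::new("wf"));
        done.status = WorkflowStatus::Completed;
        let mut running = WorkflowState::new(WorkflowInstanceId::new(), WorkflowId::new("wf"));
        running.status = WorkflowStatus::Running;

        store.save(done).unwrap();
        store.save(running).unwrap();

        // A TTL of zero makes every terminal workflow immediately eligible.
        let removed = store.cleanup_old_workflows(Duration::from_secs(0));
        assert_eq!(removed, 1);
        assert_eq!(store.count(), 1);
    }
}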
//! Workflow step implementations
//!
//! This module contains concrete step implementations for various workflows:
//! - Worker registration and activation
//! - Future: Tokenizer fetching, LoRA updates, etc.
pub mod worker_registration;
pub use worker_registration::{
create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep,
DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep,
};
//! Core workflow types and definitions
use std::{collections::HashMap, fmt, sync::Arc, time::Duration};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Unique identifier for a workflow definition
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkflowId(String);
impl WorkflowId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
}
impl fmt::Display for WorkflowId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Unique identifier for a workflow instance
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkflowInstanceId(Uuid);
impl WorkflowInstanceId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
}
impl Default for WorkflowInstanceId {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for WorkflowInstanceId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Unique identifier for a workflow step
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StepId(String);
impl StepId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
}
impl fmt::Display for StepId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Retry policy configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
pub max_attempts: u32,
pub backoff: BackoffStrategy,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
backoff: BackoffStrategy::Exponential {
base: Duration::from_secs(1),
max: Duration::from_secs(30),
},
}
}
}
/// Backoff strategy for retries
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BackoffStrategy {
/// Fixed delay between retries
Fixed(Duration),
/// Exponential backoff with base and max duration
Exponential { base: Duration, max: Duration },
/// Linear backoff with increment and max duration
Linear { increment: Duration, max: Duration },
}
/// Action to take when a step fails
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FailureAction {
/// Stop the entire workflow
FailWorkflow,
/// Skip this step and continue to the next
ContinueNextStep,
/// Keep retrying indefinitely until manual intervention
RetryIndefinitely,
}
/// Workflow execution status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
Pending,
Running,
Paused,
Completed,
Failed,
Cancelled,
}
/// Step execution status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StepStatus {
Pending,
Running,
Succeeded,
Failed,
Retrying,
Skipped,
}
/// State of a workflow step
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepState {
pub status: StepStatus,
pub attempt: u32,
pub last_error: Option<String>,
pub started_at: Option<DateTime<Utc>>,
pub completed_at: Option<DateTime<Utc>>,
}
impl Default for StepState {
fn default() -> Self {
Self {
status: StepStatus::Pending,
attempt: 0,
last_error: None,
started_at: None,
completed_at: None,
}
}
}
/// Workflow instance state
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
pub instance_id: WorkflowInstanceId,
pub definition_id: WorkflowId,
pub status: WorkflowStatus,
pub current_step: Option<StepId>,
pub step_states: HashMap<StepId, StepState>,
pub context: WorkflowContext,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl WorkflowState {
pub fn new(instance_id: WorkflowInstanceId, definition_id: WorkflowId) -> Self {
let now = Utc::now();
Self {
instance_id,
definition_id,
status: WorkflowStatus::Pending,
current_step: None,
step_states: HashMap::new(),
context: WorkflowContext::new(instance_id),
created_at: now,
updated_at: now,
}
}
}
/// Shared context passed between workflow steps
///
/// # Serialization Warning
///
/// The `data` field contains type-erased values that cannot be serialized.
/// This means workflow context is **not preserved** across:
/// - Process restarts
/// - State persistence to disk
/// - Network serialization
///
/// The workflow engine only supports **in-memory execution**. If you need
/// durable workflows, consider implementing a custom serializable context type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowContext {
pub instance_id: WorkflowInstanceId,
#[serde(skip)]
data: HashMap<String, Arc<dyn std::any::Any + Send + Sync>>,
}
impl WorkflowContext {
pub fn new(instance_id: WorkflowInstanceId) -> Self {
Self {
instance_id,
data: HashMap::new(),
}
}
/// Store a value in the context (will be wrapped in Arc)
pub fn set<T: Send + Sync + 'static>(&mut self, key: impl Into<String>, value: T) {
self.data.insert(key.into(), Arc::new(value));
}
/// Store an Arc directly without double-wrapping
pub fn set_arc<T: Send + Sync + 'static>(&mut self, key: impl Into<String>, value: Arc<T>) {
self.data.insert(key.into(), value);
}
/// Retrieve a value from the context
pub fn get<T: Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
self.data
.get(key)
.and_then(|v| v.clone().downcast::<T>().ok())
}
/// Check if the context has any data that would be lost during serialization
pub fn has_unserializable_data(&self) -> bool {
!self.data.is_empty()
}
/// Get the number of context entries (useful for debugging)
pub fn data_len(&self) -> usize {
self.data.len()
}
}
/// Result returned by a step execution
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepResult {
Success,
Failure,
Skip,
}
/// Error kinds for workflow operations
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum WorkflowError {
#[error("Workflow not found: {0}")]
NotFound(WorkflowInstanceId),
#[error("Workflow definition not found: {0}")]
DefinitionNotFound(WorkflowId),
#[error("Step failed: {step_id} - {message}")]
StepFailed { step_id: StepId, message: String },
#[error("Step timeout: {step_id}")]
StepTimeout { step_id: StepId },
#[error("Workflow cancelled: {0}")]
Cancelled(WorkflowInstanceId),
#[error("Invalid state transition: {from:?} -> {to:?}")]
InvalidStateTransition {
from: WorkflowStatus,
to: WorkflowStatus,
},
#[error("Context value not found: {0}")]
ContextValueNotFound(String),
#[error("Context value type mismatch: {0}")]
ContextTypeMismatch(String),
}
pub type WorkflowResult<T> = Result<T, WorkflowError>;
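// Illustrative sketch, not part of this commit: the typed context accessors.
// Values are stored type-erased behind `Arc`, so `get` only succeeds when the
// requested type matches what was stored.
#[cfg(test)]
mod context_sketch {
    use std::sync::Arc;

    use super::*;

    #[test]
    fn typed_round_trip() {
        let mut ctx = WorkflowContext::new(WorkflowInstanceId::new());
        ctx.set("max_attempts", 3u32);
        ctx.set_arc("worker_url", Arc::new("http://worker-0:8000".to_string()));

        assert_eq!(*ctx.get::<u32>("max_attempts").unwrap(), 3);
        assert_eq!(*ctx.get::<String>("worker_url").unwrap(), "http://worker-0:8000");
        // Asking for the wrong type (or a missing key) yields None.
        assert!(ctx.get::<u64>("max_attempts").is_none());
        assert!(ctx.get::<String>("missing").is_none());
    }
}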
......@@ -530,6 +530,39 @@ impl RouterMetrics {
)
.increment(1);
}
pub fn set_job_queue_depth(depth: usize) {
gauge!("sgl_router_job_queue_depth").set(depth as f64);
}
pub fn record_job_duration(job_type: &str, duration: Duration) {
histogram!("sgl_router_job_duration_seconds",
"job_type" => job_type.to_string()
)
.record(duration.as_secs_f64());
}
pub fn record_job_success(job_type: &str) {
counter!("sgl_router_job_success_total",
"job_type" => job_type.to_string()
)
.increment(1);
}
pub fn record_job_failure(job_type: &str) {
counter!("sgl_router_job_failure_total",
"job_type" => job_type.to_string()
)
.increment(1);
}
pub fn record_job_queue_full() {
counter!("sgl_router_job_queue_full_total").increment(1);
}
pub fn record_job_shutdown_rejected() {
counter!("sgl_router_job_shutdown_rejected_total").increment(1);
}
}
impl TokenizerMetrics {
......
......@@ -56,6 +56,51 @@ pub struct WorkerConfigRequest {
/// Additional labels (optional)
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub labels: HashMap<String, String>,
/// Health check timeout in seconds (default: 30)
#[serde(default = "default_health_check_timeout")]
pub health_check_timeout_secs: u64,
/// Health check interval in seconds (default: 60)
#[serde(default = "default_health_check_interval")]
pub health_check_interval_secs: u64,
/// Number of successful health checks needed to mark worker as healthy (default: 2)
#[serde(default = "default_health_success_threshold")]
pub health_success_threshold: u32,
/// Number of failed health checks before marking worker as unhealthy (default: 3)
#[serde(default = "default_health_failure_threshold")]
pub health_failure_threshold: u32,
/// Maximum connection attempts during worker registration (default: 20)
#[serde(default = "default_max_connection_attempts")]
pub max_connection_attempts: u32,
/// Enable data parallelism aware scheduling (default: false)
#[serde(default)]
pub dp_aware: bool,
}
// Default value functions for serde
fn default_health_check_timeout() -> u64 {
30
}
fn default_health_check_interval() -> u64 {
60
}
fn default_health_success_threshold() -> u32 {
2
}
fn default_health_failure_threshold() -> u32 {
3
}
fn default_max_connection_attempts() -> u32 {
20
}
/// Worker information for API responses
......
......@@ -22,8 +22,8 @@ use tracing::{error, info, warn, Level};
use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{
worker_to_info, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry,
WorkerType,
worker_to_info, workflow::WorkflowEngine, Job, JobQueue, JobQueueConfig, LoadMonitor,
WorkerManager, WorkerRegistry, WorkerType,
},
data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
......@@ -77,6 +77,7 @@ pub struct AppContext {
pub configured_reasoning_parser: Option<String>,
pub configured_tool_parser: Option<String>,
pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
pub workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
}
impl AppContext {
......@@ -95,6 +96,7 @@ impl AppContext {
conversation_item_storage: SharedConversationItemStorage,
load_monitor: Option<Arc<LoadMonitor>>,
worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
) -> Self {
let configured_reasoning_parser = router_config.reasoning_parser.clone();
let configured_tool_parser = router_config.tool_call_parser.clone();
......@@ -116,6 +118,7 @@ impl AppContext {
configured_reasoning_parser,
configured_tool_parser,
worker_job_queue,
workflow_engine,
}
}
}
......@@ -979,8 +982,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.router_config.worker_startup_check_interval_secs,
)));
// Create empty OnceLock for worker job queue (will be initialized below)
// Create empty OnceLock for worker job queue and workflow engine (will be initialized below)
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
// Create AppContext with all initialized components
let app_context = AppContext::new(
......@@ -997,6 +1001,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
);
let app_context = Arc::new(app_context);
......@@ -1008,17 +1013,38 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.set(worker_job_queue)
.expect("JobQueue should only be initialized once");
// Initialize workflow engine and register workflows
let engine = Arc::new(WorkflowEngine::new());
engine
.event_bus()
.subscribe(Arc::new(crate::core::workflow::LoggingSubscriber))
.await;
engine.register_workflow(crate::core::workflow::create_worker_registration_workflow());
app_context
.workflow_engine
.set(engine)
.expect("WorkflowEngine should only be initialized once");
info!("Workflow engine initialized with worker registration workflow");
info!(
"Initializing workers for routing mode: {:?}",
config.router_config.mode
);
WorkerManager::initialize_workers(
&config.router_config,
&app_context.worker_registry,
Some(&app_context.policy_registry),
)
.await
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
// Submit worker initialization job to queue
let job_queue = app_context
.worker_job_queue
.get()
.expect("JobQueue should be initialized");
let job = Job::InitializeWorkersFromConfig {
router_config: Box::new(config.router_config.clone()),
};
job_queue
.submit(job)
.await
.map_err(|e| format!("Failed to submit worker initialization job: {}", e))?;
let worker_stats = app_context.worker_registry.stats();
info!(
......
......@@ -18,7 +18,11 @@ use rustls;
use tokio::{task, time};
use tracing::{debug, error, info, warn};
use crate::{core::WorkerManager, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
use crate::{
core::{Job, WorkerManager},
protocols::worker_spec::WorkerConfigRequest,
server::AppContext,
};
#[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig {
......@@ -157,6 +161,7 @@ impl PodInfo {
}
pub fn worker_url(&self, port: u16) -> String {
// Default to http:// prefix; workflow will detect actual protocol (HTTP vs gRPC)
format!("http://{}:{}", self.ip, port)
}
}
......@@ -382,10 +387,18 @@ async fn handle_pod_event(
tool_parser: None,
chat_template: None,
api_key: None,
health_check_timeout_secs: app_context.router_config.health_check.timeout_secs,
health_check_interval_secs: app_context
.router_config
.health_check
.check_interval_secs,
health_success_threshold: app_context.router_config.health_check.success_threshold,
health_failure_threshold: app_context.router_config.health_check.failure_threshold,
max_connection_attempts: app_context.router_config.health_check.success_threshold
* 20,
dp_aware: false,
};
// Submit job for async worker addition
use crate::core::Job;
let job = Job::AddWorker {
config: Box::new(config.clone()),
};
......@@ -568,6 +581,7 @@ mod tests {
configured_reasoning_parser: None,
configured_tool_parser: None,
worker_job_queue: Arc::new(std::sync::OnceLock::new()),
workflow_engine: Arc::new(std::sync::OnceLock::new()),
})
}
......@@ -815,19 +829,6 @@ mod tests {
assert!(!not_running_pod.is_healthy());
}
#[test]
fn test_pod_info_worker_url() {
let pod_info = PodInfo {
name: "p1".into(),
ip: "1.2.3.4".into(),
status: "Running".into(),
is_ready: true,
pod_type: None,
bootstrap_port: None,
};
assert_eq!(pod_info.worker_url(8080), "http://1.2.3.4:8080");
}
#[test]
fn test_pod_info_equality_with_pod_type() {
let pod1 = PodInfo {
......
......@@ -62,8 +62,9 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
config.worker_startup_check_interval_secs,
)));
// Create empty OnceLock for worker job queue
// Create empty OnceLock for worker job queue and workflow engine
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
Arc::new(AppContext::new(
config,
......@@ -79,6 +80,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
))
}
......
......@@ -53,8 +53,9 @@ pub fn create_test_app(
router_config.worker_startup_check_interval_secs,
)));
// Create empty OnceLock for worker job queue
// Create empty OnceLock for worker job queue and workflow engine
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
// Create AppContext
let app_context = Arc::new(AppContext::new(
......@@ -71,6 +72,7 @@ pub fn create_test_app(
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
));
// Create AppState with the test router and context
......
......@@ -36,6 +36,12 @@ async fn test_policy_registry_with_router_manager() {
reasoning_parser: None,
tool_parser: None,
chat_template: None,
health_check_timeout_secs: 30,
health_check_interval_secs: 60,
health_success_threshold: 2,
health_failure_threshold: 3,
max_connection_attempts: 20,
dp_aware: false,
};
// This would normally connect to a real worker, but for testing we'll just verify the structure
......@@ -61,6 +67,12 @@ async fn test_policy_registry_with_router_manager() {
reasoning_parser: None,
tool_parser: None,
chat_template: None,
health_check_timeout_secs: 30,
health_check_interval_secs: 60,
health_success_threshold: 2,
health_failure_threshold: 3,
max_connection_attempts: 20,
dp_aware: false,
};
// The second worker should use the same policy as the first (cache_aware)
......@@ -82,6 +94,12 @@ async fn test_policy_registry_with_router_manager() {
reasoning_parser: None,
tool_parser: None,
chat_template: None,
health_check_timeout_secs: 30,
health_check_interval_secs: 60,
health_success_threshold: 2,
health_failure_threshold: 3,
max_connection_attempts: 20,
dp_aware: false,
};
let _gpt_policy = policy_registry.get_policy("gpt-4");
......
......@@ -238,8 +238,9 @@ mod test_pd_routing {
config.worker_startup_check_interval_secs,
)));
// Create empty OnceLock for worker job queue
// Create empty OnceLock for worker job queue and workflow engine
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
Arc::new(sglang_router_rs::server::AppContext::new(
config,
......@@ -255,6 +256,7 @@ mod test_pd_routing {
conversation_item_storage,
load_monitor,
worker_job_queue,
workflow_engine,
))
};
let result = RouterFactory::create_router(&app_context).await;
......