task.rs 3.51 KB
Newer Older
1
2
3
4
5
6
7
8
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::{Result, anyhow, bail};
9
use tokio::sync::mpsc;
10
use tokio::time::Instant;
11
use uuid::Uuid;
12
13
14

use crate::common::protocols::DirectRequest;

15
use super::ReplayRouter;
16
use super::state::{
17
18
    RequestRegistry, RequestState, SharedLiveRuntimeStats, WorkloadDispatchState, now_ms,
    request_uuid,
19
20
21
22
23
24
25
26
27
28
29
};

/// Handles a request task needs to route, dispatch, and account for one request.
///
/// `run_request_task` takes this by value, so callers presumably clone one per
/// spawned request.
#[derive(Clone)]
pub(super) struct RequestTaskContext {
    /// One channel per worker; a request is sent on the entry whose index the
    /// router selected.
    pub(super) senders: Arc<[mpsc::UnboundedSender<DirectRequest>]>,
    /// Chooses which worker index receives each request.
    pub(super) router: Arc<ReplayRouter>,
    /// Per-request state keyed by UUID: inserted before dispatch, removed after
    /// completion or after a failed send.
    pub(super) requests: RequestRegistry,
    /// Live counters updated on dispatch and on completion.
    pub(super) stats: Arc<SharedLiveRuntimeStats>,
    /// When `Some`, the workload driver is told about each completion and its
    /// waiters are woken.
    pub(super) workload: Option<Arc<WorkloadDispatchState>>,
}

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/// Releases a `WorkloadDriver` cap slot on drop if `mark_completed` was not called.
/// Preserves the drop-safety of the old `OwnedSemaphorePermit` so a cancelled or
/// panicking request task can't leak capacity.
pub(super) struct InFlightGuard {
    dispatch: Arc<WorkloadDispatchState>,
    uuid: Uuid,
    completed: bool,
}

impl InFlightGuard {
    pub(super) fn new(dispatch: Arc<WorkloadDispatchState>, uuid: Uuid) -> Self {
        Self {
            dispatch,
            uuid,
            completed: false,
        }
    }

    pub(super) fn mark_completed(&mut self) {
        self.completed = true;
    }
}

impl Drop for InFlightGuard {
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        if let Ok(mut driver) = self.dispatch.driver.lock() {
            driver.release_cap_slot(self.uuid);
        }
        self.dispatch.wakeup.notify_waiters();
    }
}

65
66
67
68
69
70
71
pub(super) async fn wait_for_workload_progress<F>(
    next_ready_ms: Option<f64>,
    start: Instant,
    mut wake: Pin<&mut F>,
) where
    F: Future<Output = ()>,
{
72
73
    match next_ready_ms {
        Some(next_ready_ms) => {
74
75
76
77
78
79
            let deadline = start + tokio::time::Duration::from_secs_f64(next_ready_ms / 1000.0);
            tokio::select! {
                _ = tokio::time::sleep_until(deadline) => {}
                _ = wake.as_mut() => {}
            }
        }
80
        None => {
81
82
83
84
85
86
87
88
            wake.as_mut().await;
        }
    }
}

pub(super) async fn run_request_task(
    ctx: RequestTaskContext,
    request: DirectRequest,
89
    mut guard: Option<InFlightGuard>,
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
) -> Result<()> {
    let uuid = request_uuid(&request)?;

    let worker_idx = ctx
        .router
        .select_worker(&request, ctx.senders.len())
        .await?;
    if worker_idx >= ctx.senders.len() {
        bail!("online replay selected unknown worker index {worker_idx}");
    }

    let state = Arc::new(RequestState::default());
    ctx.requests.insert(uuid, Arc::clone(&state));
    if let Err(error) = ctx.senders[worker_idx].send(request) {
        ctx.requests.remove(&uuid);
        return Err(anyhow!(
            "online replay failed to dispatch request to worker {worker_idx}: {error}"
        ));
    }

    ctx.stats.record_dispatch(worker_idx);
    state.wait_for_completion().await;
    ctx.stats.record_completion();
    ctx.requests.remove(&uuid);
    if let Some(workload) = ctx.workload.as_ref() {
        let completion_ms = now_ms(workload.start);
        workload
            .driver
            .lock()
            .unwrap()
            .on_complete(uuid, completion_ms)?;
        workload.wakeup.notify_waiters();
122
123
124
        if let Some(guard) = guard.as_mut() {
            guard.mark_completed();
        }
125
126
127
    }
    Ok(())
}