"requirements/common.txt" did not exist on "7bc94a0fddcd62d20b40390a7efb69c7a221ae5b"
thread_pool.rs 13.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{
    sync::{Arc, Mutex, atomic::AtomicUsize},
    thread::JoinHandle,
    time::Duration,
};

use async_trait::async_trait;
use dashmap::DashMap;
use rustc_hash::FxBuildHasher;
use tokio::sync::oneshot;

15
16
17
use super::{
    KvIndexerInterface, KvIndexerMetrics, KvRouterError, ShardSizeSnapshot, SyncIndexer, WorkerTask,
};
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use crate::protocols::*;

/// Generic wrapper that provides [`KvIndexerInterface`] for any [`SyncIndexer`] backend.
///
/// Spawns N OS threads for processing write events (sticky-routed by WorkerId).
/// Read operations (find_matches) are executed inline on the caller's thread,
/// avoiding channel overhead and allowing reads to scale with callers.
///
/// Worker threads are asked to terminate and are joined either by an explicit
/// `shutdown()` call or when the indexer is dropped.
///
/// # Architecture
///
/// ```text
///                                       +------------------------------------+
///                                       |     N Worker Threads (OS threads)  |
///                                       |                                    |
///  worker_event_channels[0] ----------> |   Thread 0: blocking recv loop     |
///  worker_event_channels[1] ----------> |   Thread 1: blocking recv loop     |
///  worker_event_channels[N-1] --------> |   Thread N-1: blocking recv loop   |
///                                       |                                    |
///  find_matches() ---(inline)---------> |   Arc<T: SyncIndexer>              |
///                                       |   (shared, thread-safe)            |
///                                       +------------------------------------+
/// ```
pub struct ThreadPoolIndexer<T: SyncIndexer> {
    /// Backend shared (through `Arc`) between callers and every worker thread.
    /// NOTE(review): thread-safety is presumably guaranteed by the
    /// `SyncIndexer` contract — confirm against the trait definition.
    backend: Arc<T>,

    /// Maps WorkerId to worker thread index for sticky routing.
    /// Entries are inserted lazily by `apply_event`.
    worker_assignments: DashMap<WorkerId, usize, FxBuildHasher>,
    /// Monotonic counter; its value modulo `num_workers` picks the thread for
    /// each newly seen WorkerId (round-robin).
    worker_assignment_count: AtomicUsize,

    /// Channels to send tasks to worker threads (one per thread).
    /// Sending `WorkerTask::Terminate` signals the thread to shut down.
    worker_event_channels: Vec<flume::Sender<WorkerTask>>,

    /// Number of worker threads.
    num_workers: usize,
    /// Block size for KV cache.
    kv_block_size: u32,

    /// Handles to worker threads for joining on shutdown.
    /// Emptied (via `mem::take`) by whichever of `shutdown()` / `drop` runs first.
    thread_handles: Mutex<Vec<JoinHandle<()>>>,
}

impl<T: SyncIndexer> ThreadPoolIndexer<T> {
    /// Create a new `ThreadPoolIndexer` wrapping the given backend.
    ///
    /// Spawns `num_workers` OS threads, each running a blocking recv loop
    /// that processes events by calling `backend.apply_event()`.
    ///
    /// # Arguments
    ///
    /// * `backend` - The thread-safe data structure to wrap
    /// * `num_workers` - Number of worker threads for event processing
    /// * `kv_block_size` - Block size for KV cache
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is 0.
    pub fn new(backend: T, num_workers: usize, kv_block_size: u32) -> Self {
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
        Self::new_with_metrics(backend, num_workers, kv_block_size, None)
    }

    /// Create a new `ThreadPoolIndexer` with optional metrics.
    ///
    /// Same as [`new`](Self::new) but allows passing `KvIndexerMetrics` so that
    /// each worker thread records `kv_cache_events_applied` counters, matching
    /// the observability of the single-threaded `KvIndexer` path.
    ///
    /// # Arguments
    ///
    /// * `backend` - The thread-safe data structure to wrap
    /// * `num_workers` - Number of worker threads for event processing
    /// * `kv_block_size` - Block size for KV cache
    /// * `metrics` - Optional metrics to record event application counts
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is 0.
    pub fn new_with_metrics(
        backend: T,
        num_workers: usize,
        kv_block_size: u32,
        metrics: Option<Arc<KvIndexerMetrics>>,
    ) -> Self {
103
        assert!(num_workers > 0, "Number of workers must be greater than 0");
104
        super::warn_on_unit_block_size("thread_pool", kv_block_size);
105
106
107
108
109
110
111
112
113

        let backend = Arc::new(backend);
        let mut worker_event_senders = Vec::new();
        let mut thread_handles = Vec::new();
        for _ in 0..num_workers {
            let (event_sender, event_receiver) = flume::unbounded::<WorkerTask>();
            worker_event_senders.push(event_sender);

            let backend = Arc::clone(&backend);
114
            let metrics = metrics.clone();
115
116

            let handle = std::thread::spawn(move || {
117
                backend.worker(event_receiver, metrics).unwrap();
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
            });
            thread_handles.push(handle);
        }

        Self {
            backend,
            worker_assignments: DashMap::with_hasher(FxBuildHasher),
            worker_assignment_count: AtomicUsize::new(0),
            worker_event_channels: worker_event_senders,
            num_workers,
            kv_block_size,
            thread_handles: Mutex::new(thread_handles),
        }
    }

    /// Borrow the wrapped backend for the lifetime of `&self`.
    pub fn backend(&self) -> &T {
        self.backend.as_ref()
    }

138
139
140
141
142
143
144
145
146
    /// Get a cloned `Arc` to the underlying backend.
    ///
    /// Useful when a caller needs to hand off an owned `Arc<T>` to a blocking
    /// task (e.g. `tokio::task::spawn_blocking`) without cloning the backend
    /// itself.
    ///
    /// The returned handle points at the same backend instance the worker
    /// threads use; only the reference count is incremented.
    pub fn backend_arc(&self) -> Arc<T> {
        Arc::clone(&self.backend)
    }

147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
    /// Wait until every worker channel reports an empty queue.
    ///
    /// Polls the channels, sleeping 1 ms between checks. Intended mainly for
    /// tests and benchmarks that need queued events to have been picked up by
    /// the workers before inspecting results.
    ///
    /// Only queue emptiness is observed: a task that a worker has already
    /// dequeued may still be in progress when this returns.
    pub async fn flush(&self) {
        let poll_interval = Duration::from_millis(1);
        while self.worker_event_channels.iter().any(|queue| !queue.is_empty()) {
            tokio::time::sleep(poll_interval).await;
        }
    }
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

    /// Queue a `WorkerTask::CleanupStaleChildren` on worker `thread_idx`, but
    /// only if the backend's `try_schedule_cleanup()` returns `true`.
    ///
    /// If the task cannot be delivered, the reservation is rolled back with
    /// `cancel_scheduled_cleanup()` and the failure is logged.
    fn maybe_enqueue_cleanup(&self, thread_idx: usize) {
        if self.backend.try_schedule_cleanup() {
            let cleanup_queue = &self.worker_event_channels[thread_idx];
            if let Err(e) = cleanup_queue.send(WorkerTask::CleanupStaleChildren) {
                // Undo the scheduling so a later call can try again.
                self.backend.cancel_scheduled_cleanup();
                tracing::error!(
                    "Failed to send cleanup task to worker thread {}: {:?}",
                    thread_idx,
                    e
                );
            }
        }
    }
179
180
}

181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
impl<T: SyncIndexer> Drop for ThreadPoolIndexer<T> {
    /// Stop and join all worker threads before the backend `Arc` is released.
    ///
    /// Never panics on a poisoned `thread_handles` mutex: panicking inside
    /// `drop` while another panic is unwinding would abort the process.
    fn drop(&mut self) {
        // Send Terminate to all worker threads so they exit their recv loops
        // and drop their Arc<T> clones. Then join the threads to ensure the
        // clones are actually dropped before the compiler drops `self.backend`.
        // Without this, worker threads may still be alive when `backend` drops,
        // keeping the Arc refcount > 0 and preventing T::drop() from running.
        for channel in &self.worker_event_channels {
            // A send error means the worker already exited; nothing to do.
            let _ = channel.send(WorkerTask::Terminate);
        }

        // `&mut self` proves exclusive access, so `get_mut` reaches the vector
        // without locking. A poisoned mutex still holds valid handles, so
        // recover them instead of panicking.
        let handles = std::mem::take(
            self.thread_handles
                .get_mut()
                .unwrap_or_else(|poisoned| poisoned.into_inner()),
        );
        for handle in handles {
            // Best effort: a worker that panicked is ignored during drop.
            let _ = handle.join();
        }
    }
}

203
204
205
206
207
208
209
210
211
212
213
214
215
216
#[async_trait]
impl<T: SyncIndexer> KvIndexerInterface for ThreadPoolIndexer<T> {
    /// Look up overlap scores for an already-hashed block sequence.
    ///
    /// Runs directly on the caller's task by calling the shared backend; no
    /// task is sent to the worker channels. The backend is called with `false`
    /// as its second argument (meaning defined by `SyncIndexer::find_matches`).
    async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        let scores = self.backend.find_matches(&sequence, false);
        Ok(scores)
    }

    /// Hash `tokens` into block hashes and look up overlap scores for them.
    ///
    /// The hashes are computed with this indexer's `kv_block_size`, the given
    /// `lora_name` and `is_eagle`, and defaults for every other
    /// `BlockHashOptions` field. The lookup itself runs inline on the caller's
    /// task, exactly like `find_matches`.
    async fn find_matches_for_request(
        &self,
        tokens: &[u32],
        lora_name: Option<&str>,
        is_eagle: Option<bool>,
    ) -> Result<OverlapScores, KvRouterError> {
        let hash_options = BlockHashOptions {
            lora_name,
            is_eagle,
            ..Default::default()
        };
        let block_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, hash_options);

        let scores = self.backend.find_matches(&block_hashes, false);
        Ok(scores)
    }

    async fn apply_event(&self, event: RouterEvent) {
        let worker_id = event.worker_id;

        // Get or assign worker thread index using sticky round-robin
        let thread_idx = *self.worker_assignments.entry(worker_id).or_insert_with(|| {
            let idx = self
                .worker_assignment_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            idx % self.num_workers
        });

        // Send event to the assigned worker thread
        if let Err(e) = self.worker_event_channels[thread_idx].send(WorkerTask::Event(event)) {
            tracing::error!(
                "Failed to send event to worker thread {}: {:?}",
                thread_idx,
                e
            );
249
            return;
250
        }
251
252

        self.maybe_enqueue_cleanup(thread_idx);
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
    }

    async fn remove_worker(&self, worker_id: WorkerId) {
        // Route to the worker's assigned thread (if any), otherwise broadcast
        // to all threads since dp_ranks may be spread across threads.
        let thread_idx = self.worker_assignments.get(&worker_id).map(|v| *v);
        match thread_idx {
            Some(idx) => {
                if let Err(e) =
                    self.worker_event_channels[idx].send(WorkerTask::RemoveWorker(worker_id))
                {
                    tracing::error!(
                        "Failed to send RemoveWorker to worker thread {}: {:?}",
                        idx,
                        e
                    );
269
                    return;
270
                }
271
272

                self.maybe_enqueue_cleanup(idx);
273
274
275
276
277
278
            }
            None => {
                // Worker was never assigned a thread - broadcast to all
                for channel in &self.worker_event_channels {
                    let _ = channel.send(WorkerTask::RemoveWorker(worker_id));
                }
279
                self.maybe_enqueue_cleanup(0);
280
281
282
283
284
285
286
287
288
289
            }
        }
    }

    async fn remove_worker_dp_rank(&self, worker_id: WorkerId, dp_rank: DpRank) {
        // Broadcast to all threads — the dp_rank may be on any thread.
        // Don't remove from worker_assignments since other dp_ranks may still exist.
        for channel in &self.worker_event_channels {
            let _ = channel.send(WorkerTask::RemoveWorkerDpRank(worker_id, dp_rank));
        }
290
        self.maybe_enqueue_cleanup(0);
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
    }

    fn shutdown(&self) {
        // Send shutdown signal to all worker threads
        for channel in self.worker_event_channels.iter() {
            let _ = channel.send(WorkerTask::Terminate);
        }

        // Take ownership of thread handles and join them
        let handles = std::mem::take(
            &mut *self
                .thread_handles
                .lock()
                .expect("thread_handles mutex poisoned"),
        );
        for handle in handles {
            if let Err(e) = handle.join() {
                tracing::error!("Worker thread panicked during shutdown: {:?}", e);
            }
        }
    }

    async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
314
315
316
317
        // Send DumpEvents to every worker as a FIFO barrier: each worker must
        // finish processing all previously queued Events before it handles
        // DumpEvents, so by the time all workers respond we know the shared
        // tree (if any) reflects every event that was enqueued before this call.
318
319
320
321
322
323
324
325
326
327
328
329
330
        let mut receivers = Vec::new();

        for channel in &self.worker_event_channels {
            let (resp_tx, resp_rx) = oneshot::channel::<anyhow::Result<Vec<RouterEvent>>>();
            let dump_req = WorkerTask::DumpEvents(resp_tx);

            channel
                .send(dump_req)
                .map_err(|_| KvRouterError::IndexerOffline)?;
            receivers.push(resp_rx);
        }

        let mut all_events = Vec::new();
331
        let mut event_id_counter = 0u64;
332
333
334
335
336
337
338
339
340
341
342
343
344

        for resp_rx in receivers {
            let mut events = resp_rx
                .await
                .map_err(|_| KvRouterError::IndexerDroppedRequest)?
                .map_err(|_| KvRouterError::IndexerOffline)?;
            for event in &mut events {
                event.event.event_id = event_id_counter;
                event_id_counter += 1;
            }
            all_events.extend(events);
        }

345
346
347
348
349
350
351
352
353
        // Shared-state backends keep their tree in concurrent structures
        // readable from any thread. Now that the barrier above guarantees
        // all queued writes have landed, dump directly.
        if let Some(events) = self.backend.dump_events() {
            return Ok(events);
        }

        // Per-thread-state backends returned their events through the DumpEvents
        // responses collected above.
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
        Ok(all_events)
    }

    /// Accept a routing decision without recording anything.
    ///
    /// Both arguments are ignored and the call always returns `Ok(())`.
    async fn process_routing_decision_for_request(
        &self,
        _tokens_with_hashes: &mut TokensWithHashes,
        _worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        // No-op: pruning not supported in ThreadPoolIndexer
        Ok(())
    }

    /// Wait until every worker channel is empty and report the backlog size.
    ///
    /// Returns the total number of tasks that were queued across all channels
    /// when the call started. The channels are then polled every 1 ms until
    /// all of them report empty.
    async fn flush(&self) -> usize {
        let queued_at_entry = self
            .worker_event_channels
            .iter()
            .map(|ch| ch.len())
            .sum::<usize>();

        while !self.worker_event_channels.iter().all(|ch| ch.is_empty()) {
            tokio::time::sleep(Duration::from_millis(1)).await;
        }

        queued_at_entry
    }
379
380
381
382
383
384
385
386
387
388
389
390
391

    /// Report backend size counters as a single shard.
    ///
    /// Always returns exactly one snapshot, with `shard_idx` 0, filled from
    /// the backend's worker, block and node counts.
    fn shard_sizes(&self) -> Vec<ShardSizeSnapshot> {
        let backend = &self.backend;
        let snapshot = ShardSizeSnapshot {
            shard_idx: 0,
            worker_count: backend.worker_count(),
            block_count: backend.block_count(),
            node_count: backend.node_count(),
        };
        vec![snapshot]
    }

    /// Return the backend's `node_edge_lengths()` result unchanged.
    fn node_edge_lengths(&self) -> Vec<usize> {
        self.backend.node_edge_lengths()
    }
392
}