"vllm/tool_parsers/__init__.py" did not exist on "e261d37c9a5e88a6c86d32decf39f1fab7ca1f2c"
responder.rs 11.1 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use tokio::sync::mpsc;

use std::collections::HashSet;
use std::sync::Arc;

use crate::{BlockId, G2, G3, InstanceId, SequenceHash, worker::group::ParallelWorkers};
use kvbm_logical::manager::BlockManager;

use super::{BlockHolder, SessionId, messages::OnboardMessage, transport::MessageTransport};

/// Responder-side session for handling block onboarding requests.
///
/// Lifecycle:
/// 1. Spawned when receiving CreateSession
/// 2. Searches local G2 for matches
/// 3. Holds `ImmutableBlock<G2>` references (RAII)
/// 4. Sends G2Results immediately
/// 5. Searches local G3 for remaining matches (if G3 available)
/// 6. Sends G3Results
/// 7. Receives HoldBlocks and filters held G2 blocks
/// 8. Receives StageBlocks and executes G3->G2 transfers
/// 9. Sends BlocksReady when staging completes
/// 10. Sends Acknowledged
/// 11. Completes and drops (releases blocks)
pub struct ResponderSession {
    session_id: SessionId,
    instance_id: InstanceId,
    requester: InstanceId,
    g2_manager: Arc<BlockManager<G2>>,
    g3_manager: Option<Arc<BlockManager<G3>>>,
    parallel_worker: Option<Arc<dyn ParallelWorkers>>,
    transport: Arc<MessageTransport>,
    // Held blocks using BlockHolder for RAII semantics
    // Blocks are automatically released when the session drops
    held_g2_blocks: BlockHolder<G2>,
    held_g3_blocks: BlockHolder<G3>,
}

impl ResponderSession {
    /// Create a new responder session.
    pub fn new(
        session_id: SessionId,
        instance_id: InstanceId,
        requester: InstanceId,
        g2_manager: Arc<BlockManager<G2>>,
        g3_manager: Option<Arc<BlockManager<G3>>>,
        parallel_worker: Option<Arc<dyn ParallelWorkers>>,
        transport: Arc<MessageTransport>,
    ) -> Self {
        Self {
            session_id,
            instance_id,
            requester,
            g2_manager,
            g3_manager,
            parallel_worker,
            transport,
            held_g2_blocks: BlockHolder::empty(),
            held_g3_blocks: BlockHolder::empty(),
        }
    }

    /// Run the responder session task.
    ///
    /// This is the main session loop that processes messages from the channel.
    pub async fn run(
        mut self,
        mut rx: mpsc::Receiver<OnboardMessage>,
        sequence_hashes: Vec<SequenceHash>,
    ) -> Result<()> {
        // Phase 1: Immediate G2 search
        // Use scan_matches instead of match_blocks to find all matching blocks
        // without stopping on first miss (supports partial sequence matching)
        let g2_matches_map = self.g2_manager.scan_matches(&sequence_hashes, true);
        let mut g2_matches: Vec<_> = g2_matches_map.into_values().collect();

        // Sort by position to ensure G2Results are in position order
        // HashMap iteration order is arbitrary, so we must sort explicitly
        g2_matches.sort_by_key(|block| block.sequence_hash().position());

        // Hold the G2 blocks using BlockHolder (RAII semantics)
        self.held_g2_blocks = BlockHolder::new(g2_matches);

        // Send G2 results immediately (fire-and-forget) with parallel arrays
        let g2_sequence_hashes: Vec<SequenceHash> = self.held_g2_blocks.sequence_hashes();
        let g2_block_ids: Vec<BlockId> = self
            .held_g2_blocks
            .blocks()
            .iter()
            .map(|b| b.block_id())
            .collect();

        let g2_msg = OnboardMessage::G2Results {
            responder: self.instance_id,
            session_id: self.session_id,
            sequence_hashes: g2_sequence_hashes,
            block_ids: g2_block_ids,
        };
        self.transport.send(self.requester, g2_msg).await?;

        // Phase 2: Search G3 for remaining hashes (if G3 available)
        let g2_matched_hashes: HashSet<SequenceHash> =
            self.held_g2_blocks.sequence_hashes().into_iter().collect();

        let remaining_hashes: Vec<SequenceHash> = sequence_hashes
            .iter()
            .filter(|h| !g2_matched_hashes.contains(h))
            .copied()
            .collect();

        if !remaining_hashes.is_empty()
            && let Some(ref g3_manager) = self.g3_manager
        {
            // Use scan_matches instead of match_blocks to find all matching blocks
            // without stopping on first miss (supports partial sequence matching)
            let g3_matches_map = g3_manager.scan_matches(&remaining_hashes, true);
            let mut g3_matches: Vec<_> = g3_matches_map.into_values().collect();

            // Sort by position to ensure G3Results are in position order
            g3_matches.sort_by_key(|block| block.sequence_hash().position());

            if !g3_matches.is_empty() {
                // Hold the G3 blocks using BlockHolder
                self.held_g3_blocks = BlockHolder::new(g3_matches);

                // Send G3 results (sequence hashes only, keep order)
                let g3_sequence_hashes: Vec<SequenceHash> = self.held_g3_blocks.sequence_hashes();

                let g3_msg = OnboardMessage::G3Results {
                    responder: self.instance_id,
                    session_id: self.session_id,
                    sequence_hashes: g3_sequence_hashes,
                };
                self.transport.send(self.requester, g3_msg).await?;
            }
        }

        // Send SearchComplete to signal we're done searching
        let complete_msg = OnboardMessage::SearchComplete {
            responder: self.instance_id,
            session_id: self.session_id,
        };
        self.transport.send(self.requester, complete_msg).await?;

        // Phase 3: Process incoming messages
        while let Some(msg) = rx.recv().await {
            match msg {
                OnboardMessage::HoldBlocks {
                    hold_hashes,
                    drop_hashes: _,
                    ..
                } => {
                    // Filter by sequence hash - BlockHolder's retain keeps only matching hashes
                    self.held_g2_blocks.retain(&hold_hashes);
                    self.held_g3_blocks.retain(&hold_hashes);

                    // Send acknowledgment
                    let ack = OnboardMessage::Acknowledged {
                        responder: self.instance_id,
                        session_id: self.session_id,
                    };
                    self.transport.send(self.requester, ack).await?;

                    // Always wait for CloseSession, even if no G3 blocks
                    // This ensures proper session lifecycle and avoids race conditions
                    // where initiator sends CloseSession after we've already exited
                }

                OnboardMessage::StageBlocks { stage_hashes, .. } => {
                    // Filter G3 blocks to only keep blocks to be staged
                    // BlockHolder's retain keeps only matching hashes
                    self.held_g3_blocks.retain(&stage_hashes);

                    if !self.held_g3_blocks.is_empty() {
                        if self.parallel_worker.is_some() {
                            // Execute G3->G2 transfer
                            self.stage_g3_to_g2().await?;
                        } else {
                            tracing::warn!(
                                session_id = %self.session_id,
                                g3_blocks = self.held_g3_blocks.count(),
                                "G3 blocks cannot be staged: no parallel worker configured"
                            );
                        }
                    }

                    // Don't exit - wait for CloseSession in Hold/Prepare modes
                }

                OnboardMessage::ReleaseBlocks { release_hashes, .. } => {
                    // Release specific blocks by sequence hash
                    // BlockHolder's release removes blocks with given hashes
                    self.held_g2_blocks.release(&release_hashes);
                    self.held_g3_blocks.release(&release_hashes);
                }

                // todo: how does close session drop the session from the dashmap?
                // todo: do we need to handle this in the handler rather than the session responder loop?
                OnboardMessage::CloseSession { .. } => {
                    // Session complete - release all blocks and exit
                    // take_all() explicitly releases the blocks
                    let _ = self.held_g2_blocks.take_all();
                    let _ = self.held_g3_blocks.take_all();
                    break;
                }

                OnboardMessage::CreateSession { .. } => {
                    // Duplicate CreateSession - ignore
                }

                // todo: be explicit about what messages are expected and what messages are unexpected
                //       on the responder session - avoid using the wildcard match
                _ => {
                    // Unexpected message - log and ignore
                    tracing::warn!(
                        session_id = %self.session_id,
                        msg = ?msg,
                        "ResponderSession: unexpected message"
                    );
                }
            }

            // TODO: Add heartbeat/TTL timeout handling
            // If no message received within TTL duration:
            // - Release all held blocks
            // - Exit session
            // Implementation:
            //   tokio::select! {
            //       msg = rx.recv() => { /* process message */ }
            //       _ = tokio::time::sleep_until(ttl_deadline) => {
            //           eprintln!("Session {} TTL expired, releasing blocks", self.session_id);
            //           break;
            //       }
            //   }
        }

        Ok(())
    }

    /// Stage G3 blocks to G2.
    async fn stage_g3_to_g2(&mut self) -> Result<()> {
        let parallel_worker = self
            .parallel_worker
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("ParallelWorker required for G3->G2 staging"))?;

        let result = super::staging::stage_g3_to_g2(
            &self.held_g3_blocks,
            &self.g2_manager,
            &**parallel_worker,
        )
        .await?;

        // Extract sequence hashes and block IDs for newly staged blocks
        let new_sequence_hashes: Vec<SequenceHash> = result
            .new_g2_blocks
            .iter()
            .map(|b| b.sequence_hash())
            .collect();
        let new_block_ids: Vec<BlockId> =
            result.new_g2_blocks.iter().map(|b| b.block_id()).collect();

        // Release G3 blocks (take_all releases them) and hold new G2 blocks
        let _ = self.held_g3_blocks.take_all();
        self.held_g2_blocks.extend(result.new_g2_blocks);

        // Send BlocksReady with only newly staged blocks
        let ready_msg = OnboardMessage::BlocksReady {
            responder: self.instance_id,
            session_id: self.session_id,
            sequence_hashes: new_sequence_hashes,
            block_ids: new_block_ids,
        };
        self.transport.send(self.requester, ready_msg).await?;

        Ok(())
    }
}