state.rs 10.1 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! LeaderState - Coordination layer for managing workers.
//!
//! This module provides the leader's coordination state, including:
//! - Worker registration and rank mapping
//! - Remote leader tracking for cross-leader transfers
//! - Routing strategies for asymmetric TP configurations

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use ::velo::Messenger;
use anyhow::Result;

use crate::InstanceId;
use crate::worker::{CoordinatedWorker, Worker};
use kvbm_physical::manager::SerializedLayout;

/// Info about a remote leader and its workers.
///
/// `Clone` is derived so that callers can take an owned snapshot of an entry
/// (for example out of a lock-guarded map) without rebuilding it field by field.
#[derive(Debug, Clone)]
pub struct RemoteLeaderInfo {
    /// Instance ID of the remote leader process
    pub instance_id: InstanceId,
    /// Number of workers under the remote leader
    pub worker_count: usize,
    /// Cached metadata from remote workers (rank-ordered, one entry per remote worker)
    pub worker_metadata: Vec<SerializedLayout>,
}

/// Leader coordination state - owns workers and routing logic.
///
/// LeaderState manages:
/// - Registration of workers during handshake phase
/// - Coordination with remote leaders for cross-leader transfers
/// - Routing strategies for asymmetric TP configurations
///
/// Workers are added through `register_worker` (which needs `&mut self`),
/// while remote leaders are recorded through `&self` methods; that is why only
/// the remote-leader map sits behind a lock.
pub struct LeaderState {
    /// This leader's instance ID
    instance_id: InstanceId,

    /// Nova runtime for RPC, held as a shared `Messenger` handle
    messenger: Arc<Messenger>,

    /// Workers under this leader (rank-ordered: the index in this vector is the worker's rank)
    workers: Vec<CoordinatedWorker>,

    /// Known remote leaders (by their instance ID), guarded by an `RwLock` so
    /// entries can be inserted and read through a shared reference
    remote_leaders: RwLock<HashMap<InstanceId, RemoteLeaderInfo>>,
}

impl LeaderState {
    /// Create a new LeaderState with no registered workers and no known remote leaders.
    ///
    /// # Arguments
    /// * `instance_id` - This leader's unique identifier
    /// * `messenger` - Nova runtime (messenger) for RPC communication
    pub fn new(instance_id: InstanceId, messenger: Arc<Messenger>) -> Self {
        Self {
            instance_id,
            messenger,
            workers: Vec::new(),
            remote_leaders: RwLock::new(HashMap::new()),
        }
    }

    /// Get this leader's instance ID.
    pub fn instance_id(&self) -> InstanceId {
        self.instance_id
    }

    /// Get the Nova runtime.
    ///
    /// Returns the shared `Messenger` handle that was passed to [`LeaderState::new`].
    pub fn nova(&self) -> &Arc<Messenger> {
        &self.messenger
    }

    /// Register a worker during the handshake phase.
    ///
    /// Workers should be registered in rank order (0, 1, 2, ...). Registering a
    /// rank that already exists replaces the previously registered worker.
    ///
    /// # Arguments
    /// * `rank` - The worker's rank (0-indexed)
    /// * `host_instance` - Instance ID of the process hosting this worker
    /// * `worker` - The Worker implementation (DirectWorker or VeloWorkerClient)
    ///
    /// # Panics
    /// Panics if `rank` is greater than the number of workers registered so
    /// far, i.e. if registering it would leave a gap in the rank sequence.
    pub fn register_worker(
        &mut self,
        rank: usize,
        host_instance: InstanceId,
        worker: Box<dyn Worker>,
    ) {
        let coordinated = CoordinatedWorker::new(worker, rank, host_instance);

        // Ensure rank-ordered insertion
        if rank == self.workers.len() {
            // Sequential append (expected path)
            self.workers.push(coordinated);
        } else if rank < self.workers.len() {
            // Re-registration or out-of-order within existing range
            self.workers[rank] = coordinated;
        } else {
            panic!(
                "Gap in worker ranks: rank {} but only {} workers registered",
                rank,
                self.workers.len()
            );
        }
    }

    /// Number of workers under this leader.
    pub fn worker_count(&self) -> usize {
        self.workers.len()
    }

    /// Get a worker by rank, or `None` if no worker with that rank is registered.
    pub fn worker(&self, rank: usize) -> Option<&CoordinatedWorker> {
        self.workers.get(rank)
    }

    /// Get a mutable worker by rank, or `None` if no worker with that rank is registered.
    pub fn worker_mut(&mut self, rank: usize) -> Option<&mut CoordinatedWorker> {
        self.workers.get_mut(rank)
    }

    /// Iterate over all workers in rank order.
    pub fn workers(&self) -> impl Iterator<Item = &CoordinatedWorker> {
        self.workers.iter()
    }

    /// Connect to a remote leader and distribute its worker metadata to our workers.
    ///
    /// This implements the routing strategy for cross-leader transfers
    /// (see [`route_local_to_remote`]):
    /// - 1:1 mapping when TP sizes match
    /// - Many-to-one when local TP > remote TP
    /// - One-to-many when local TP < remote TP
    ///
    /// The remote leader is recorded before the metadata is handed to the
    /// local workers, so it stays recorded even if a worker import fails.
    ///
    /// # Arguments
    /// * `remote_leader_id` - Instance ID of the remote leader
    /// * `remote_metadata` - Metadata from each remote worker (rank-ordered)
    ///
    /// # Errors
    /// Returns an error if `remote_metadata` is empty while local workers are
    /// registered (there is no remote rank to route them to), or if any local
    /// worker fails to import its share of the metadata. In the first case
    /// nothing is recorded.
    pub async fn import_remote_leader(
        &self,
        remote_leader_id: InstanceId,
        remote_metadata: Vec<SerializedLayout>,
    ) -> Result<()> {
        let remote_count = remote_metadata.len();
        let local_count = self.workers.len();

        tracing::info!(
            local_count,
            remote_count,
            %remote_leader_id,
            "Importing remote leader metadata"
        );

        // Validate before touching any state: routing local workers onto zero
        // remote workers would otherwise hit a modulo by zero in the routing
        // step, after the remote leader had already been recorded.
        if remote_count == 0 && local_count > 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "remote leader {} supplied no worker metadata; cannot route {} local workers",
                    remote_leader_id, local_count
                ),
            )
            .into());
        }

        // Store remote leader info. The guard lives only inside this block so
        // the lock is never held across an `.await`.
        {
            let mut leaders = self
                .remote_leaders
                .write()
                .expect("remote_leaders lock poisoned");
            leaders.insert(
                remote_leader_id,
                RemoteLeaderInfo {
                    instance_id: remote_leader_id,
                    worker_count: remote_count,
                    worker_metadata: remote_metadata.clone(),
                },
            );
        }

        // Distribute metadata based on routing strategy
        for (local_rank, worker) in self.workers.iter().enumerate() {
            let target_remote_ranks = route_local_to_remote(local_rank, local_count, remote_count);

            for remote_rank in target_remote_ranks {
                tracing::debug!(
                    local_rank,
                    remote_rank,
                    %remote_leader_id,
                    "Importing remote metadata for local worker"
                );

                // Routed ranks are always < remote_count for a valid local
                // rank and a non-zero remote count, so indexing is in bounds.
                worker
                    .import_remote_metadata(
                        remote_leader_id,
                        remote_rank,
                        remote_metadata[remote_rank].clone(),
                    )
                    .await?;
            }
        }

        Ok(())
    }

    /// Export this leader's workers' metadata for another leader to import.
    ///
    /// Returns metadata from each worker in rank order. Each worker's export
    /// is requested and awaited before the next worker is asked.
    ///
    /// # Errors
    /// Returns the first error produced while requesting or awaiting a
    /// worker's metadata; metadata collected up to that point is discarded.
    pub async fn export_worker_metadata(&self) -> Result<Vec<SerializedLayout>> {
        let mut metadata = Vec::with_capacity(self.workers.len());

        for worker in &self.workers {
            let response = worker.inner().export_metadata()?;
            metadata.push(response.await?);
        }

        Ok(metadata)
    }

    /// Check if we have imported metadata from a remote leader.
    pub fn has_remote_leader(&self, remote_leader_id: InstanceId) -> bool {
        self.remote_leaders
            .read()
            .expect("remote_leaders lock poisoned")
            .contains_key(&remote_leader_id)
    }

    /// Get info about a remote leader if known.
    ///
    /// The returned value is an owned copy of the stored entry, so the lock
    /// is released before this method returns.
    pub fn remote_leader_info(&self, remote_leader_id: InstanceId) -> Option<RemoteLeaderInfo> {
        self.remote_leaders
            .read()
            .expect("remote_leaders lock poisoned")
            .get(&remote_leader_id)
            .map(|info| RemoteLeaderInfo {
                instance_id: info.instance_id,
                worker_count: info.worker_count,
                worker_metadata: info.worker_metadata.clone(),
            })
    }
}

/// Routing strategy: which remote ranks a given local rank receives from.
///
/// This function determines how metadata/transfers are routed when
/// the local and remote TP sizes differ.
///
/// # Arguments
/// * `local_rank` - Rank of the local worker being routed (0-indexed, expected `< local_count`)
/// * `local_count` - Number of local workers (local TP size)
/// * `remote_count` - Number of remote workers (remote TP size)
///
/// # Returns
/// The remote ranks assigned to `local_rank`, in ascending order. The list is
/// empty when either side has no workers, since there is nothing to route.
///
/// # Examples
/// - TP=4 local, TP=4 remote: 1:1 mapping (rank 0→0, 1→1, 2→2, 3→3)
/// - TP=4 local, TP=2 remote: 0→0, 1→1, 2→0, 3→1 (many-to-one, round-robin by modulo)
/// - TP=2 local, TP=4 remote: 0→\[0,1\], 1→\[2,3\] (one-to-many)
/// - TP=2 local, TP=5 remote: 0→\[0,1\], 1→\[2,3,4\] (last local rank absorbs the remainder)
pub fn route_local_to_remote(
    local_rank: usize,
    local_count: usize,
    remote_count: usize,
) -> Vec<usize> {
    // With no workers on either side there is no valid target. Returning early
    // also keeps the modulo and division below from panicking on zero.
    if local_count == 0 || remote_count == 0 {
        return Vec::new();
    }

    if local_count == remote_count {
        // 1:1 mapping
        vec![local_rank]
    } else if local_count > remote_count {
        // Many local → few remote: multiple locals share a remote
        vec![local_rank % remote_count]
    } else {
        // Few local → many remote: each local gets a contiguous run of remotes
        let remotes_per_local = remote_count / local_count;
        let start = local_rank * remotes_per_local;
        // Last local rank absorbs any remainder from non-divisible ratios
        let end = if local_rank == local_count - 1 {
            remote_count
        } else {
            start + remotes_per_local
        };
        (start..end).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Collect the routing result for every local rank `0..local_count`, in rank order.
    fn routes(local_count: usize, remote_count: usize) -> Vec<Vec<usize>> {
        (0..local_count)
            .map(|local_rank| route_local_to_remote(local_rank, local_count, remote_count))
            .collect()
    }

    #[test]
    fn test_route_1_to_1() {
        // Equal TP sizes: each local rank talks to the remote rank with the same index.
        assert_eq!(routes(4, 4), vec![vec![0], vec![1], vec![2], vec![3]]);
    }

    #[test]
    fn test_route_many_to_one() {
        // Local TP=4, Remote TP=2: local ranks wrap around the remote ranks.
        assert_eq!(routes(4, 2), vec![vec![0], vec![1], vec![0], vec![1]]);
    }

    #[test]
    fn test_route_one_to_many() {
        // Local TP=2, Remote TP=4: each local rank takes a contiguous pair.
        assert_eq!(routes(2, 4), vec![vec![0, 1], vec![2, 3]]);
    }

    #[test]
    fn test_route_4_to_8() {
        // Local TP=4, Remote TP=8: two remote ranks per local rank.
        assert_eq!(
            routes(4, 8),
            vec![vec![0, 1], vec![2, 3], vec![4, 5], vec![6, 7]]
        );
    }

    #[test]
    fn test_route_non_divisible_remainder() {
        // Local TP=2, Remote TP=5: the last local rank picks up the leftover remote rank.
        assert_eq!(routes(2, 5), vec![vec![0, 1], vec![2, 3, 4]]);

        // Local TP=3, Remote TP=7: again only the last local rank gets the extra.
        assert_eq!(routes(3, 7), vec![vec![0, 1], vec![2, 3], vec![4, 5, 6]]);
    }
}