main.rs 5.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;

use clap::Parser;
use tokio::net::TcpListener;

mod indexer;
mod listener;
11
mod recovery;
12
13
14
15
16
17
18
19
20
mod registry;
mod server;

use registry::WorkerRegistry;
use server::{AppState, create_router};

#[derive(Parser)]
#[command(name = "dynamo-kv-indexer", about = "Standalone KV cache indexer")]
struct Cli {
21
    /// KV cache block size for initial workers registered via --workers
22
    #[arg(long)]
23
    block_size: Option<u32>,
24
25
26
27
28
29

    /// HTTP server port
    #[arg(long, default_value_t = 8090)]
    port: u16,

    /// Number of indexer threads (1 = single-threaded KvIndexer, >1 = ThreadPoolIndexer)
30
    #[arg(long, default_value_t = 4)]
31
32
    threads: usize,

33
    /// Initial workers as "worker_id[:dp_rank]=zmq_address,..." (e.g. "1=tcp://host:5557,1:1=tcp://host:5558")
34
35
    #[arg(long)]
    workers: Option<String>,
36
37
38
39
40
41
42
43

    /// Model name for initial workers registered via --workers
    #[arg(long, default_value = "default")]
    model_name: String,

    /// Tenant ID for initial workers registered via --workers
    #[arg(long, default_value = "default")]
    tenant_id: String,
44
45
46
47

    /// Comma-separated peer URLs for P2P recovery (e.g. "http://host1:8090,http://host2:8091")
    #[arg(long)]
    peers: Option<String>,
48
49
}

50
/// Parses the `--workers` value into `(worker_id, dp_rank, zmq_address)` tuples.
///
/// The expected format is a comma-separated list of
/// `worker_id[:dp_rank]=zmq_address` entries, e.g.
/// `"1=tcp://host:5557,1:1=tcp://host:5558"`. When `:dp_rank` is omitted the
/// rank defaults to `0`.
///
/// Whitespace is tolerated around every component (the entry, the id, the
/// rank and the address), so `"1 : 2 = tcp://host:5557"` is accepted.
///
/// Entries that cannot be parsed (no `=`, non-numeric id or rank, or blank
/// entries such as those produced by a trailing comma) are skipped; the
/// result preserves the input order of the entries that parsed.
fn parse_workers(s: &str) -> Vec<(u64, u32, String)> {
    s.split(',')
        .filter_map(|entry| {
            // Split on the first '=' only: the address itself may contain
            // ':' (e.g. "tcp://host:5557"), so the id part must be isolated
            // before looking for the optional ":dp_rank" suffix. Blank
            // entries have no '=' and are dropped here.
            let (id_part, addr) = entry.split_once('=')?;

            let (id, dp_rank) = match id_part.split_once(':') {
                Some((id_str, rank_str)) => (
                    id_str.trim().parse::<u64>().ok()?,
                    rank_str.trim().parse::<u32>().ok()?,
                ),
                None => (id_part.trim().parse::<u64>().ok()?, 0),
            };

            Some((id, dp_rank, addr.trim().to_string()))
        })
        .collect()
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();

    let cli = Cli::parse();

77
78
79
80
81
82
83
84
85
86
87
    let peers: Vec<String> = cli
        .peers
        .as_deref()
        .map(|s| {
            s.split(',')
                .filter(|p| !p.is_empty())
                .map(|p| p.trim().to_string())
                .collect()
        })
        .unwrap_or_default();

88
    tracing::info!(
89
        block_size = ?cli.block_size,
90
91
        port = cli.port,
        threads = cli.threads,
92
93
        model_name = %cli.model_name,
        tenant_id = %cli.tenant_id,
94
        num_peers = peers.len(),
95
96
97
        "Starting standalone KV cache indexer"
    );

98
    let registry = WorkerRegistry::new(cli.threads);
99

100
101
102
    // Register initial workers — connects ZMQ sockets but listeners wait
    // for the ready signal. This ensures ZMQ subscription handshakes begin
    // before P2P recovery fetches the dump from a peer.
103
    if let Some(ref workers_str) = cli.workers {
104
105
106
        let block_size = cli.block_size.ok_or_else(|| {
            anyhow::anyhow!("--block-size is required when --workers is specified")
        })?;
107
108
        for (instance_id, dp_rank, endpoint) in parse_workers(workers_str) {
            tracing::info!(instance_id, dp_rank, endpoint, "Registering initial worker");
109
110
111
            registry.register(
                instance_id,
                endpoint,
112
                dp_rank,
113
114
115
116
                cli.model_name.clone(),
                cli.tenant_id.clone(),
                block_size,
            )?;
117
118
119
        }
    }

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    // P2P recovery: fetch dump from a peer before starting ZMQ listeners.
    // The 1s delay inside recover_from_peers ensures the peer's tree has
    // advanced past our ZMQ connection floor before we fetch the dump.
    if !peers.is_empty() {
        match recovery::recover_from_peers(&peers, &registry).await {
            Ok(true) => tracing::info!("P2P recovery completed"),
            Ok(false) => tracing::warn!("no reachable peers, starting with empty state"),
            Err(e) => tracing::warn!(error = %e, "P2P recovery failed, starting with empty state"),
        }
        for peer in &peers {
            registry.register_peer(peer.clone());
        }
    }

    // Signal ready — unblocks all ZMQ listeners to start draining buffered events
    registry.signal_ready();

137
    let state = Arc::new(AppState { registry });
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152

    let app = create_router(state);
    let listener = TcpListener::bind(("0.0.0.0", cli.port)).await?;
    tracing::info!("HTTP server listening on 0.0.0.0:{}", cli.port);
    axum::serve(listener, app).await?;

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Entries with and without an explicit dp_rank parse in input order,
    /// and a missing rank defaults to 0.
    #[test]
    fn test_parse_workers() {
        let input = "1=tcp://host:5557,2:1=tcp://host:5558";
        let result = parse_workers(input);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], (1, 0, "tcp://host:5557".to_string()));
        assert_eq!(result[1], (2, 1, "tcp://host:5558".to_string()));
    }

    /// An empty `--workers` value yields no workers.
    #[test]
    fn test_parse_workers_empty() {
        assert!(parse_workers("").is_empty());
    }

    /// Entries without '=', with a non-numeric id or rank, or that are blank
    /// are skipped without affecting the remaining entries.
    #[test]
    fn test_parse_workers_skips_malformed_entries() {
        let result =
            parse_workers("no-equals,abc=tcp://host:1,1:x=tcp://host:2,,3=tcp://host:3");
        assert_eq!(result, vec![(3, 0, "tcp://host:3".to_string())]);
    }

    /// Whitespace around entries, around '=' and around the address is ignored.
    #[test]
    fn test_parse_workers_trims_whitespace() {
        let result = parse_workers(" 1 = tcp://host:5557 , 2:1 = tcp://host:5558 ");
        assert_eq!(
            result,
            vec![
                (1, 0, "tcp://host:5557".to_string()),
                (2, 1, "tcp://host:5558".to_string()),
            ]
        );
    }
}