server.rs 8.27 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
use std::sync::Arc;

use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};

use dynamo_kv_router::protocols::{LocalBlockHash, WorkerId, compute_block_hash_for_seq};

16
use super::registry::{IndexerKey, WorkerRegistry};
17
18
19

pub struct AppState {
    pub registry: WorkerRegistry,
20
21
22
23
}

/// Tenant id used when a request body omits `tenant_id`.
fn default_tenant() -> String {
    "default".to_string()
}

#[derive(Deserialize)]
pub struct RegisterRequest {
    pub instance_id: WorkerId,
    pub endpoint: String,
30
31
32
33
    pub model_name: String,
    #[serde(default = "default_tenant")]
    pub tenant_id: String,
    pub block_size: u32,
34
35
36
37
38
39
40
    #[serde(default)]
    pub dp_rank: Option<u32>,
}

#[derive(Deserialize)]
pub struct UnregisterRequest {
    pub instance_id: WorkerId,
41
42
43
    pub model_name: String,
    #[serde(default)]
    pub tenant_id: Option<String>,
44
45
46
47
48
49
50
51
52
53
54
55
56
    #[serde(default)]
    pub dp_rank: Option<u32>,
}

/// One element of the JSON array returned by `GET /workers`.
///
/// Field order is significant: serde emits fields in declaration order.
#[derive(Serialize)]
struct WorkerInfo {
    /// Identifier of the registered worker instance.
    instance_id: WorkerId,
    /// Endpoint strings reported by the registry's `list()`, presumably keyed by `dp_rank`
    /// (TODO confirm against `WorkerRegistry::list`).
    endpoints: HashMap<u32, String>,
}

#[derive(Deserialize)]
pub struct QueryRequest {
    pub token_ids: Vec<u32>,
57
58
59
    pub model_name: String,
    #[serde(default = "default_tenant")]
    pub tenant_id: String,
60
61
62
63
    #[serde(default)]
    pub lora_name: Option<String>,
}

64
65
66
67
68
/// Query using pre-computed block hashes.
///
/// Callers must include the LoRA salt in their hashes when applicable — use
/// [`compute_block_hash_for_seq`] with the appropriate `lora_name`. The indexer
/// cannot retroactively apply a LoRA salt to pre-computed hashes.
69
70
71
#[derive(Deserialize)]
pub struct QueryByHashRequest {
    pub block_hashes: Vec<i64>,
72
73
74
    pub model_name: String,
    #[serde(default = "default_tenant")]
    pub tenant_id: String,
75
76
77
78
79
80
81
82
83
84
85
86
87
}

/// Two-level map: worker id string -> dp rank string -> value.
type PerWorkerRank<V> = HashMap<String, HashMap<String, V>>;

/// JSON body returned by `POST /query` and `POST /query_by_hash`.
///
/// Field order is significant: serde emits fields in declaration order.
#[derive(Serialize)]
struct ScoreResponse {
    /// Overlap score per worker and rank, already multiplied by the indexer's block size.
    scores: PerWorkerRank<u32>,
    /// Passed through unchanged from the indexer's overlap result.
    frequencies: Vec<usize>,
    /// Tree size per worker and rank, as reported by the indexer.
    tree_sizes: PerWorkerRank<usize>,
}

async fn register(
    State(state): State<Arc<AppState>>,
    Json(req): Json<RegisterRequest>,
) -> impl IntoResponse {
88
89
90
91
92
93
94
95
    match state.registry.register(
        req.instance_id,
        req.endpoint,
        req.dp_rank.unwrap_or(0),
        req.model_name,
        req.tenant_id,
        req.block_size,
    ) {
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
        Ok(()) => (
            StatusCode::CREATED,
            Json(serde_json::json!({"status": "ok"})),
        ),
        Err(e) => (
            StatusCode::CONFLICT,
            Json(serde_json::json!({"error": e.to_string()})),
        ),
    }
}

async fn unregister(
    State(state): State<Arc<AppState>>,
    Json(req): Json<UnregisterRequest>,
) -> impl IntoResponse {
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    let result = match req.tenant_id {
        Some(tenant_id) => match req.dp_rank {
            Some(dp_rank) => {
                state
                    .registry
                    .deregister_dp_rank(req.instance_id, dp_rank, &req.model_name, &tenant_id)
                    .await
            }
            None => {
                state
                    .registry
                    .deregister(req.instance_id, &req.model_name, &tenant_id)
                    .await
            }
        },
        None => {
127
128
            state
                .registry
129
                .deregister_all_tenants(req.instance_id, &req.model_name)
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
                .await
        }
    };
    match result {
        Ok(()) => (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))),
        Err(e) => (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({"error": e.to_string()})),
        ),
    }
}

/// Handler for `GET /workers`: return every registered worker with its endpoints as a JSON
/// array of [`WorkerInfo`].
async fn list_workers(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let mut workers: Vec<WorkerInfo> = Vec::new();
    for (instance_id, endpoints) in state.registry.list() {
        workers.push(WorkerInfo {
            instance_id,
            endpoints,
        });
    }
    Json(workers)
}

155
156
157
158
fn build_score_response(
    overlap: dynamo_kv_router::protocols::OverlapScores,
    block_size: u32,
) -> ScoreResponse {
159
160
161
162
163
    let mut scores: HashMap<String, HashMap<String, u32>> = HashMap::new();
    for (k, v) in &overlap.scores {
        scores
            .entry(k.worker_id.to_string())
            .or_default()
164
            .insert(k.dp_rank.to_string(), v * block_size);
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
    }
    let mut tree_sizes: HashMap<String, HashMap<String, usize>> = HashMap::new();
    for (k, v) in &overlap.tree_sizes {
        tree_sizes
            .entry(k.worker_id.to_string())
            .or_default()
            .insert(k.dp_rank.to_string(), *v);
    }
    ScoreResponse {
        scores,
        frequencies: overlap.frequencies,
        tree_sizes,
    }
}

async fn query(
    State(state): State<Arc<AppState>>,
    Json(req): Json<QueryRequest>,
) -> impl IntoResponse {
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
    let key = IndexerKey {
        model_name: req.model_name,
        tenant_id: req.tenant_id,
    };
    let Some(ie) = state.registry.get_indexer(&key) else {
        return (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": format!("no indexer for model={} tenant={}", key.model_name, key.tenant_id)
            })),
        );
    };
    let block_size = ie.block_size;
    let indexer = ie.indexer.clone();
    drop(ie);

    let block_hashes =
        compute_block_hash_for_seq(&req.token_ids, block_size, None, req.lora_name.as_deref());
    match indexer.find_matches(block_hashes).await {
203
204
        Ok(overlap) => (
            StatusCode::OK,
205
            Json(serde_json::json!(build_score_response(overlap, block_size))),
206
207
208
209
210
211
212
213
214
215
216
217
        ),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"error": e.to_string()})),
        ),
    }
}

async fn query_by_hash(
    State(state): State<Arc<AppState>>,
    Json(req): Json<QueryByHashRequest>,
) -> impl IntoResponse {
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
    let key = IndexerKey {
        model_name: req.model_name,
        tenant_id: req.tenant_id,
    };
    let Some(ie) = state.registry.get_indexer(&key) else {
        return (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": format!("no indexer for model={} tenant={}", key.model_name, key.tenant_id)
            })),
        );
    };
    let block_size = ie.block_size;
    let indexer = ie.indexer.clone();
    drop(ie);

234
235
236
237
238
    let block_hashes: Vec<LocalBlockHash> = req
        .block_hashes
        .iter()
        .map(|h| LocalBlockHash(*h as u64))
        .collect();
239
    match indexer.find_matches(block_hashes).await {
240
241
        Ok(overlap) => (
            StatusCode::OK,
242
            Json(serde_json::json!(build_score_response(overlap, block_size))),
243
244
245
246
247
248
249
250
251
        ),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"error": e.to_string()})),
        ),
    }
}

async fn dump_events(State(state): State<Arc<AppState>>) -> impl IntoResponse {
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
    let indexers = state.registry.all_indexers();
    let mut handles = Vec::with_capacity(indexers.len());

    for (key, indexer) in indexers {
        handles.push(tokio::spawn(async move {
            let events = indexer.dump_events().await;
            (key, events)
        }));
    }

    let mut result: HashMap<String, serde_json::Value> = HashMap::new();
    for handle in handles {
        match handle.await {
            Ok((key, Ok(events))) => {
                let map_key = format!("{}:{}", key.model_name, key.tenant_id);
                result.insert(map_key, serde_json::json!(events));
            }
            Ok((key, Err(e))) => {
                let map_key = format!("{}:{}", key.model_name, key.tenant_id);
                result.insert(map_key, serde_json::json!({"error": e.to_string()}));
            }
            Err(e) => {
                tracing::warn!("dump task join error: {e}");
            }
        }
277
    }
278
    (StatusCode::OK, Json(serde_json::json!(result)))
279
280
281
282
283
284
285
286
287
288
289
290
}

/// Build the HTTP router and bind the shared [`AppState`].
///
/// Routes:
/// - `POST /register`, `POST /unregister`, `GET /workers`: worker bookkeeping.
/// - `POST /query`, `POST /query_by_hash`: overlap scoring.
/// - `GET /dump`: per-indexer event dump.
pub fn create_router(state: Arc<AppState>) -> Router {
    let worker_routes = Router::<Arc<AppState>>::new()
        .route("/register", post(register))
        .route("/unregister", post(unregister))
        .route("/workers", get(list_workers));
    let all_routes = worker_routes
        .route("/query", post(query))
        .route("/query_by_hash", post(query_by_hash))
        .route("/dump", get(dump_events));
    all_routes.with_state(state)
}