clear_kv_blocks.rs 7.45 KB
Newer Older
1
2
3
4
5
6
7
8
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{service_v2, RouteDoc};
use axum::{http::Method, response::IntoResponse, routing::post, Json, Router};
use serde_json::json;
use std::sync::Arc;

9
use dynamo_runtime::{discovery::DiscoveryQuery, pipeline::PushRouter, stream::StreamExt};

/// Name of the component endpoint that the clear-KV-blocks HTTP route calls
/// on every worker instance of each registered model's namespace/component.
pub const CLEAR_KV_ENDPOINT: &str = "clear_kv_blocks";

/// Builds the axum [`Router`] that exposes the clear-KV-blocks handler.
///
/// * `state` - shared service state handed to the handler via `with_state`.
/// * `path`  - optional mount point; defaults to `/clear_kv_blocks`.
///
/// Returns the route documentation (a single `POST` entry for the chosen
/// path) together with the configured router.
pub fn clear_kv_blocks_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let route_path = match path {
        Some(custom) => custom,
        None => String::from("/clear_kv_blocks"),
    };

    let route_docs = vec![RouteDoc::new(Method::POST, &route_path)];

    let handler = post(clear_kv_blocks_handler);
    let configured = Router::new().route(&route_path, handler).with_state(state);

    (route_docs, configured)
}

async fn clear_kv_blocks_handler(
    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse {
    let model_entries = state.manager().get_model_entries();

    // if there are no active workers
    if model_entries.is_empty() {
        return Json(serde_json::json!({
            "message": "No active worker groups found"
        }));
    }

    let distributed = match state.runtime() {
        Some(runtime) => runtime,
        None => {
            return Json(serde_json::json!({
                "message": "Failed to create distributed runtime",
            }));
        }
    };

    let mut cleared_workers = Vec::new();
    let mut failed_workers = Vec::new();

    // update cleared and failed workers
    let mut add_worker_result = |success: bool,
                                 name: String,
                                 status: &str,
                                 ns: &str,
                                 comp: &str,
                                 message: Option<String>| {
        let mut result = json!({
            "name": name,
61
            "endpoint": format!("{}/{}/{}", ns, comp, CLEAR_KV_ENDPOINT),
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
            "status": status,
        });
        if success {
            if let Some(m) = message {
                result["response"] = json!(m);
            }
            cleared_workers.push(result);
        } else {
            if let Some(m) = message {
                result["error"] = json!(m);
            }
            failed_workers.push(result);
        }
    };

    // create client for each model entry
    for entry in &model_entries {
79
80
        let namespace = &entry.endpoint_id.namespace;
        let component = &entry.endpoint_id.component;
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
        let entry_name = entry.name.to_string();

        tracing::debug!("Processing worker group: {}/{}", namespace, component);

        let namespace_obj = match distributed.namespace(namespace) {
            Ok(ns) => ns,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get namespace",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

        let component_obj = match namespace_obj.component(component) {
            Ok(comp) => comp,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get component",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

        let endpoint: dynamo_runtime::component::Endpoint =
116
            component_obj.endpoint(CLEAR_KV_ENDPOINT);
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152

        let client = match endpoint.client().await {
            Ok(c) => c,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get client",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

        let router = match PushRouter::<(), serde_json::Value>::from_client(
            client.clone(),
            Default::default(),
        )
        .await
        {
            Ok(r) => r,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to create router",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

153
154
155
156
157
158
159
160
        let discovery_client = distributed.discovery();
        let discovery_key = DiscoveryQuery::Endpoint {
            namespace: namespace.clone(),
            component: component.clone(),
            endpoint: CLEAR_KV_ENDPOINT.to_string(),
        };

        let discovery_instances = match discovery_client.list(discovery_key).await {
161
162
163
164
165
166
167
168
169
170
171
172
173
174
            Ok(instances) => instances,
            Err(e) => {
                add_worker_result(
                    false,
                    entry_name,
                    "Failed to get instances for worker group",
                    namespace,
                    component,
                    Some(e.to_string()),
                );
                continue;
            }
        };

175
        if discovery_instances.is_empty() {
176
177
178
            add_worker_result(
                false,
                entry_name,
179
                "No instances found for clear_kv_blocks endpoint",
180
181
182
183
184
185
186
                namespace,
                component,
                None,
            );
            continue;
        }

187
        let instances_filtered: Vec<dynamo_runtime::component::Instance> = discovery_instances
188
            .into_iter()
189
190
191
192
193
            .filter_map(|di| match di {
                dynamo_runtime::discovery::DiscoveryInstance::Endpoint(instance) => Some(instance),
                _ => None,
            })
            .collect();
194
195
196

        for instance in &instances_filtered {
            let instance_name = format!("{}-instance-{}", entry.name, instance.id());
197
            match router.direct(().into(), instance.id()).await {
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
                Ok(mut stream) => match stream.next().await {
                    Some(response) => {
                        add_worker_result(
                            true,
                            instance_name,
                            "Successfully cleared kv blocks for instance",
                            namespace,
                            component,
                            Some(response.to_string()),
                        );
                    }
                    None => {
                        add_worker_result(
                            false,
                            instance_name,
                            "No response from instance",
                            namespace,
                            component,
                            None,
                        );
                    }
                },
                Err(e) => {
                    add_worker_result(
                        false,
                        instance_name,
                        "Failed to send request for instance",
                        namespace,
                        component,
                        Some(e.to_string()),
                    );
                }
            }
        }
    }

    Json(serde_json::json!({
        "cleared_workers": cleared_workers,
        "failed_workers": failed_workers
    }))
}