"launch/vscode:/vscode.git/clone" did not exist on "43991e76fd2bd0348aaf2b949c64775146f556f6"
http.rs 14 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    model_card,
14
    namespace::is_global_namespace,
15
16
17
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
18
19
    },
};
20
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
21
use dynamo_runtime::{DistributedRuntime, Runtime};
22
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
23
24

/// Build and run an HTTP service
25
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
49
50
51
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
52
53
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
54

55
56
57
58
59
60
61
62
63
    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Pass the custom backend metrics endpoint as-is (already in namespace.component.endpoint format)
    http_service_builder = http_service_builder.with_custom_backend_config(
        local_model
            .custom_backend_metrics_endpoint()
            .map(|s| s.to_string()),
        local_model.custom_backend_metrics_polling_interval(),
    );

64
    let http_service = match engine_config {
65
        EngineConfig::Dynamic(_) => {
66
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
67
            // This allows the /health endpoint to query store for active instances
68
            http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
69
            let http_service = http_service_builder.build()?;
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
            let store = Arc::new(distributed_runtime.store().clone());

            let router_config = engine_config.local_model().router_config();
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
            let namespace = engine_config.local_model().namespace().unwrap_or("");
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
                distributed_runtime,
                http_service.state().manager_clone(),
                store,
                router_config.router_mode,
                Some(router_config.kv_router_config),
                router_config.busy_threshold,
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
            )
            .await?;
94
            http_service
95
        }
96
97
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
98
99
            let checksum = card.mdcsum();

100
101
            let router_mode = local_model.router_config().router_mode;

102
            let dst_config = DistributedConfig::from_settings(true); // true means static
103
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
104
            let http_service = http_service_builder.build()?;
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

127
            let tokenizer_hf = card.tokenizer_hf()?;
128
129
130
            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
131
132
133
134
135
136
137
            >(
                card,
                &client,
                router_mode,
                None,
                kv_chooser.clone(),
                tokenizer_hf.clone(),
138
                None, // No prefill chooser in http static mode
139
            )
140
            .await?;
141
142
143
144
145
            manager.add_chat_completions_model(
                local_model.display_name(),
                checksum,
                chat_engine,
            )?;
146

147
148
149
150
151
152
153
154
155
156
157
158
159
            let completions_engine = entrypoint::build_routed_pipeline::<
                NvCreateCompletionRequest,
                NvCreateCompletionResponse,
            >(
                card,
                &client,
                router_mode,
                None,
                kv_chooser,
                tokenizer_hf,
                None, // No prefill chooser in http static mode
            )
            .await?;
160
161
162
163
164
            manager.add_completions_model(
                local_model.display_name(),
                checksum,
                completions_engine,
            )?;
165

166
            for endpoint_type in EndpointType::all() {
167
                http_service.enable_model_endpoint(endpoint_type, true);
168
169
            }

170
            http_service
171
172
        }
        EngineConfig::StaticFull { engine, model, .. } => {
173
            let http_service = http_service_builder.build()?;
174
175
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
176
177
178
            let checksum = model.card().mdcsum();
            manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
179
180
181

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
182
                http_service.enable_model_endpoint(endpoint_type, true);
183
            }
184
            http_service
185
        }
186
187
        EngineConfig::StaticCore {
            engine: inner_engine,
188
            model,
189
            ..
190
        } => {
191
            let http_service = http_service_builder.build()?;
192
            let manager = http_service.model_manager();
193
            let checksum = model.card().mdcsum();
194

195
196
197
198
199
200
201
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
202
            manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
203

204
205
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
206
                NvCreateCompletionResponse,
207
            >(model.card(), inner_engine, tokenizer_hf)
208
            .await?;
209
            manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
210
211
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
212
                http_service.enable_model_endpoint(endpoint_type, true);
213
            }
214
            http_service
215
        }
216
    };
217
218
219
220
221
222
223
224
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
225
226
227
228
229
230
231
232
233
234
235
236

    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Start custom backend metrics polling if configured
    let polling_task =
        if let (Some(namespace_component_endpoint), Some(polling_interval), Some(registry)) = (
            http_service
                .custom_backend_namespace_component_endpoint
                .as_ref(),
            http_service.custom_backend_metrics_polling_interval,
            http_service.custom_backend_registry.as_ref(),
        ) {
            // Create DistributedRuntime for polling, matching the engine's mode
237
            let drt = DistributedRuntime::from_settings(runtime.clone()).await?;
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
            tracing::info!(
                namespace_component_endpoint=%namespace_component_endpoint,
                polling_interval_secs=polling_interval,
                "Starting custom backend metrics polling task"
            );
            // Spawn the polling task and keep the JoinHandle alive so it can be aborted during
            // shutdown. While graceful shutdown is not strictly necessary for this non-critical
            // metrics polling, explicitly aborting it prevents the task from running during the
            // shutdown phase.
            Some(
                crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task(
                    drt,
                    namespace_component_endpoint.clone(),
                    polling_interval,
                    registry.clone(),
                ),
            )
        } else {
            None
        };

259
    http_service.run(runtime.primary_token()).await?;
260
261
262
263
264
265

    // Abort the polling task if it was started
    if let Some(task) = polling_task {
        task.abort();
    }

266
267
    runtime.shutdown(); // Cancel primary token
    Ok(())
268
}
269

270
/// Spawns a task that watches for new models in store,
271
/// and registers them with the ModelManager so that the HTTP service can use them.
272
#[allow(clippy::too_many_arguments)]
273
async fn run_watcher(
274
    runtime: DistributedRuntime,
275
    model_manager: Arc<ModelManager>,
276
    store: Arc<KeyValueStoreManager>,
277
    router_mode: RouterMode,
278
    kv_router_config: Option<KvRouterConfig>,
279
    busy_threshold: Option<f64>,
280
    target_namespace: Option<String>,
281
    http_service: Arc<HttpService>,
282
    metrics: Arc<crate::http::service::metrics::Metrics>,
283
) -> anyhow::Result<()> {
284
    let cancellation_token = runtime.primary_token();
285
286
287
288
289
290
291
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
292
293
    tracing::debug!("Waiting for remote model");
    let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
294
295
296
297
298

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

299
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
300
    let _endpoint_enabler_task = tokio::spawn(async move {
301
302
303
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
304
305
306
307
        }
    });

    // Pass the sender to the watcher
308
    let _watcher_task = tokio::spawn(async move {
309
        watch_obj.watch(receiver, target_namespace.as_deref()).await;
310
    });
311

312
313
    Ok(())
}
314
315

/// Updates HTTP service endpoints based on available model types
316
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
317
318
319
320
321
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
322
        ModelUpdate::Added(card) => {
323
            // Handle all supported endpoint types, not just the first one
324
            for endpoint_type in card.model_type.as_endpoint_types() {
325
                service.enable_model_endpoint(endpoint_type, true);
326
            }
327
        }
328
        ModelUpdate::Removed(card) => {
329
            // Handle all supported endpoint types, not just the first one
330
            for endpoint_type in card.model_type.as_endpoint_types() {
331
                service.enable_model_endpoint(endpoint_type, false);
332
            }
333
        }
334
335
    }
}
336
337

/// Updates metrics for model type changes
338
fn update_model_metrics(
339
340
341
342
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
343
344
345
346
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
347
348
            }
        }
349
350
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
351
352
353
354
355
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}