http.rs 13.6 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    model_card,
14
    namespace::is_global_namespace,
15
16
17
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
18
19
    },
};
20
21
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
22
use dynamo_runtime::storage::key_value_store::KeyValueStoreManager;
23
24

/// Build and run an HTTP service
25
26
27
28
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
52
53
54
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
55
56
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
57

58
59
60
61
62
63
64
65
66
    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Pass the custom backend metrics endpoint as-is (already in namespace.component.endpoint format)
    http_service_builder = http_service_builder.with_custom_backend_config(
        local_model
            .custom_backend_metrics_endpoint()
            .map(|s| s.to_string()),
        local_model.custom_backend_metrics_polling_interval(),
    );

67
    let http_service = match engine_config {
68
        EngineConfig::Dynamic(_) => {
69
            // This allows the /health endpoint to query store for active instances
70
            http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
71
            let http_service = http_service_builder.build()?;
72
73
74
75
76
77
78
79
80
81
82
83
84
            let store = Arc::new(distributed_runtime.store().clone());

            let router_config = engine_config.local_model().router_config();
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
            let namespace = engine_config.local_model().namespace().unwrap_or("");
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
85
                distributed_runtime.clone(),
86
87
88
89
90
91
92
93
94
95
                http_service.state().manager_clone(),
                store,
                router_config.router_mode,
                Some(router_config.kv_router_config),
                router_config.busy_threshold,
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
            )
            .await?;
96
            http_service
97
        }
98
99
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
100
            let checksum = card.mdcsum();
101
            let router_mode = local_model.router_config().router_mode;
102
            let http_service = http_service_builder.build()?;
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

125
            let tokenizer_hf = card.tokenizer_hf()?;
126
127
128
            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
129
130
131
132
133
134
135
            >(
                card,
                &client,
                router_mode,
                None,
                kv_chooser.clone(),
                tokenizer_hf.clone(),
136
                None, // No prefill chooser in http static mode
137
            )
138
            .await?;
139
140
141
142
143
            manager.add_chat_completions_model(
                local_model.display_name(),
                checksum,
                chat_engine,
            )?;
144

145
146
147
148
149
150
151
152
153
154
155
156
157
            let completions_engine = entrypoint::build_routed_pipeline::<
                NvCreateCompletionRequest,
                NvCreateCompletionResponse,
            >(
                card,
                &client,
                router_mode,
                None,
                kv_chooser,
                tokenizer_hf,
                None, // No prefill chooser in http static mode
            )
            .await?;
158
159
160
161
162
            manager.add_completions_model(
                local_model.display_name(),
                checksum,
                completions_engine,
            )?;
163

164
            for endpoint_type in EndpointType::all() {
165
                http_service.enable_model_endpoint(endpoint_type, true);
166
167
            }

168
            http_service
169
170
        }
        EngineConfig::StaticFull { engine, model, .. } => {
171
            let http_service = http_service_builder.build()?;
172
173
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
174
            let checksum = model.card().mdcsum();
175
176
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
177
178
179

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
180
                http_service.enable_model_endpoint(endpoint_type, true);
181
            }
182
            http_service
183
        }
184
185
        EngineConfig::StaticCore {
            engine: inner_engine,
186
            model,
187
            ..
188
        } => {
189
            let http_service = http_service_builder.build()?;
190
            let manager = http_service.model_manager();
191
            let checksum = model.card().mdcsum();
192

193
194
195
196
197
198
199
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
200
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
201

202
203
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
204
                NvCreateCompletionResponse,
205
            >(model.card(), inner_engine, tokenizer_hf)
206
            .await?;
207
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
208
209
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
210
                http_service.enable_model_endpoint(endpoint_type, true);
211
            }
212
            http_service
213
        }
214
    };
215
216
217
218
219
220
221
222
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244

    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Start custom backend metrics polling if configured
    let polling_task =
        if let (Some(namespace_component_endpoint), Some(polling_interval), Some(registry)) = (
            http_service
                .custom_backend_namespace_component_endpoint
                .as_ref(),
            http_service.custom_backend_metrics_polling_interval,
            http_service.custom_backend_registry.as_ref(),
        ) {
            tracing::info!(
                namespace_component_endpoint=%namespace_component_endpoint,
                polling_interval_secs=polling_interval,
                "Starting custom backend metrics polling task"
            );
            // Spawn the polling task and keep the JoinHandle alive so it can be aborted during
            // shutdown. While graceful shutdown is not strictly necessary for this non-critical
            // metrics polling, explicitly aborting it prevents the task from running during the
            // shutdown phase.
            Some(
                crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task(
245
                    distributed_runtime.clone(),
246
247
248
249
250
251
252
253
254
                    namespace_component_endpoint.clone(),
                    polling_interval,
                    registry.clone(),
                ),
            )
        } else {
            None
        };

255
256
257
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
258
259
260
261
262
263

    // Abort the polling task if it was started
    if let Some(task) = polling_task {
        task.abort();
    }

264
    distributed_runtime.shutdown(); // Cancel primary token
265
    Ok(())
266
}
267

268
/// Spawns a task that watches for new models in store,
269
/// and registers them with the ModelManager so that the HTTP service can use them.
270
#[allow(clippy::too_many_arguments)]
271
async fn run_watcher(
272
    runtime: DistributedRuntime,
273
    model_manager: Arc<ModelManager>,
274
    store: Arc<KeyValueStoreManager>,
275
    router_mode: RouterMode,
276
    kv_router_config: Option<KvRouterConfig>,
277
    busy_threshold: Option<f64>,
278
    target_namespace: Option<String>,
279
    http_service: Arc<HttpService>,
280
    metrics: Arc<crate::http::service::metrics::Metrics>,
281
) -> anyhow::Result<()> {
282
    let cancellation_token = runtime.primary_token();
283
284
285
286
287
288
289
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
290
291
    tracing::debug!("Waiting for remote model");
    let (_, receiver) = store.watch(model_card::ROOT_PATH, None, cancellation_token);
292
293
294
295
296

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

297
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
298
    let _endpoint_enabler_task = tokio::spawn(async move {
299
300
301
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
302
303
304
305
        }
    });

    // Pass the sender to the watcher
306
    let _watcher_task = tokio::spawn(async move {
307
        watch_obj.watch(receiver, target_namespace.as_deref()).await;
308
    });
309

310
311
    Ok(())
}
312
313

/// Updates HTTP service endpoints based on available model types
314
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
315
316
317
318
319
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
320
        ModelUpdate::Added(card) => {
321
            // Handle all supported endpoint types, not just the first one
322
            for endpoint_type in card.model_type.as_endpoint_types() {
323
                service.enable_model_endpoint(endpoint_type, true);
324
            }
325
        }
326
        ModelUpdate::Removed(card) => {
327
            // Handle all supported endpoint types, not just the first one
328
            for endpoint_type in card.model_type.as_endpoint_types() {
329
                service.enable_model_endpoint(endpoint_type, false);
330
            }
331
        }
332
333
    }
}
334
335

/// Updates metrics for model type changes
336
fn update_model_metrics(
337
338
339
340
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
341
342
343
344
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
345
346
            }
        }
347
348
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
349
350
351
352
353
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}