http.rs 13.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    namespace::is_global_namespace,
14
15
16
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
17
18
    },
};
19
use dynamo_runtime::transports::etcd;
20
use dynamo_runtime::{DistributedRuntime, Runtime};
21
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
22
23

/// Build and run an HTTP service
24
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
    let http_service = match engine_config {
55
        EngineConfig::Dynamic(_) => {
56
57
58
59
60
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            let etcd_client = distributed_runtime.etcd_client();
            // This allows the /health endpoint to query etcd for active instances
            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
            let http_service = http_service_builder.build()?;
61
62
            match etcd_client {
                Some(ref etcd_client) => {
63
                    let router_config = engine_config.local_model().router_config();
64
                    // Listen for models registering themselves in etcd, add them to HTTP service
65
66
67
68
69
70
71
72
                    // Check if we should filter by namespace (based on the local model's namespace)
                    // Get namespace from the model, fallback to endpoint_id namespace if not set
                    let namespace = engine_config.local_model().namespace().unwrap_or("");
                    let target_namespace = if is_global_namespace(namespace) {
                        None
                    } else {
                        Some(namespace.to_string())
                    };
73
                    run_watcher(
74
                        distributed_runtime,
75
                        http_service.state().manager_clone(),
76
                        etcd_client.clone(),
77
                        MODEL_ROOT_PATH,
78
                        router_config.router_mode,
79
                        Some(router_config.kv_router_config),
80
                        router_config.busy_threshold,
81
                        target_namespace,
82
                        Arc::new(http_service.clone()),
83
                        http_service.state().metrics_clone(),
84
85
                    )
                    .await?;
86
87
88
89
90
                }
                None => {
                    // Static endpoints don't need discovery
                }
            }
91
            http_service
92
        }
93
94
95
96
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
            let router_mode = local_model.router_config().router_mode;

97
            let dst_config = DistributedConfig::from_settings(true); // true means static
98
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
99
            let http_service = http_service_builder.build()?;
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

123
            let tokenizer_hf = card.tokenizer_hf()?;
124
125
126
            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
127
128
129
130
131
132
133
134
            >(
                card,
                &client,
                router_mode,
                None,
                kv_chooser.clone(),
                tokenizer_hf.clone(),
            )
135
136
137
            .await?;
            manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;

138
139
140
141
142
143
            let completions_engine =
                entrypoint::build_routed_pipeline::<
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
                .await?;
144
            manager.add_completions_model(local_model.display_name(), completions_engine)?;
145

146
            for endpoint_type in EndpointType::all() {
147
                http_service.enable_model_endpoint(endpoint_type, true);
148
149
            }

150
            http_service
151
152
        }
        EngineConfig::StaticFull { engine, model, .. } => {
153
            let http_service = http_service_builder.build()?;
154
155
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
156
157
            manager.add_completions_model(model.service_name(), engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), engine)?;
158
159
160

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
161
                http_service.enable_model_endpoint(endpoint_type, true);
162
            }
163
            http_service
164
        }
165
166
        EngineConfig::StaticCore {
            engine: inner_engine,
167
            model,
168
            ..
169
        } => {
170
            let http_service = http_service_builder.build()?;
171
172
            let manager = http_service.model_manager();

173
174
175
176
177
178
179
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
180
            manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
181

182
183
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
184
                NvCreateCompletionResponse,
185
            >(model.card(), inner_engine, tokenizer_hf)
186
            .await?;
187
            manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
188
189
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
190
                http_service.enable_model_endpoint(endpoint_type, true);
191
            }
192
            http_service
193
        }
194
    };
195
196
197
198
199
200
201
202
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
203
204
205
    http_service.run(runtime.primary_token()).await?;
    runtime.shutdown(); // Cancel primary token
    Ok(())
206
}
207
208
209

/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
210
#[allow(clippy::too_many_arguments)]
211
async fn run_watcher(
212
    runtime: DistributedRuntime,
213
    model_manager: Arc<ModelManager>,
214
215
    etcd_client: etcd::Client,
    network_prefix: &str,
216
    router_mode: RouterMode,
217
    kv_router_config: Option<KvRouterConfig>,
218
    busy_threshold: Option<f64>,
219
    target_namespace: Option<String>,
220
    http_service: Arc<HttpService>,
221
    metrics: Arc<crate::http::service::metrics::Metrics>,
222
) -> anyhow::Result<()> {
223
224
225
    // Clone model_manager before it's moved into ModelWatcher
    let model_manager_clone = model_manager.clone();

226
227
228
229
230
231
232
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
233
234
235
    tracing::info!("Watching for remote model at {network_prefix}");
    let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
    let (_prefix, _watcher, receiver) = models_watcher.dissolve();
236
237
238
239
240
241

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);

    watch_obj.set_notify_on_model_update(tx);

242
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
243
244
245
    let _endpoint_enabler_task = tokio::spawn(async move {
        while let Some(model_type) = rx.recv().await {
            tracing::debug!("Received model type update: {:?}", model_type);
246
247

            // Update HTTP endpoints (existing functionality)
248
            update_http_endpoints(http_service.clone(), model_type);
249
250
251
252
253
254
255
256
257

            // Update metrics (only for added models)
            update_model_metrics(
                model_type,
                model_manager_clone.clone(),
                metrics.clone(),
                Some(etcd_client.clone()),
            )
            .await;
258
259
260
261
        }
    });

    // Pass the sender to the watcher
262
    let _watcher_task = tokio::spawn(async move {
263
        watch_obj.watch(receiver, target_namespace.as_deref()).await;
264
    });
265

266
267
    Ok(())
}
268
269

/// Updates HTTP service endpoints based on available model types
270
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
271
272
273
274
275
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
276
277
278
279
        ModelUpdate::Added(model_type) => {
            // Handle all supported endpoint types, not just the first one
            for endpoint_type in model_type.as_endpoint_types() {
                service.enable_model_endpoint(endpoint_type, true);
280
            }
281
282
283
284
285
        }
        ModelUpdate::Removed(model_type) => {
            // Handle all supported endpoint types, not just the first one
            for endpoint_type in model_type.as_endpoint_types() {
                service.enable_model_endpoint(endpoint_type, false);
286
            }
287
        }
288
289
    }
}
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332

/// Refreshes metrics in response to a model update.
///
/// For `ModelUpdate::Added`, each entry from the model manager whose `model_type`
/// equals the added type gets its runtime-config metrics updated (when the entry
/// carries a runtime config) and, if an etcd client was supplied, its MDC metrics
/// updated as well. A failed MDC update is logged at debug level and otherwise
/// ignored. `ModelUpdate::Removed` only emits a debug log.
async fn update_model_metrics(
    model_type: ModelUpdate,
    model_manager: Arc<ModelManager>,
    metrics: Arc<crate::http::service::metrics::Metrics>,
    etcd_client: Option<etcd::Client>,
) {
    let added_type = match model_type {
        ModelUpdate::Added(added) => added,
        ModelUpdate::Removed(removed) => {
            tracing::debug!("Model type removed: {:?}", removed);
            // Note: metrics are typically not removed, to preserve historical data.
            // This matches the behavior in the polling task.
            return;
        }
    };

    tracing::debug!("Updating metrics for added model type: {:?}", added_type);

    // Walk every known model entry and only touch those of the added type.
    let entries = model_manager.get_model_entries();
    for entry in entries {
        if entry.model_type != added_type {
            continue;
        }

        // Runtime-config metrics are only available when the entry carries a config.
        if let Some(runtime_config) = &entry.runtime_config {
            metrics.update_runtime_config_metrics(&entry.name, runtime_config);
        }

        // MDC metrics require etcd; without a client there is nothing more to do.
        let Some(etcd) = etcd_client.as_ref() else {
            continue;
        };
        if let Err(e) = metrics
            .update_metrics_from_model_entry_with_mdc(&entry, etcd)
            .await
        {
            tracing::debug!(
                model = %entry.name,
                error = %e,
                "Failed to update MDC metrics for newly added model"
            );
        }
    }
}