// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    namespace::is_global_namespace,
14
15
16
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
17
18
    },
};
19
20
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
21
22

/// Build and run an HTTP service
23
24
25
26
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
50
51
52
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
53
54
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
55

56
57
58
59
60
61
62
63
64
    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Pass the custom backend metrics endpoint as-is (already in namespace.component.endpoint format)
    http_service_builder = http_service_builder.with_custom_backend_config(
        local_model
            .custom_backend_metrics_endpoint()
            .map(|s| s.to_string()),
        local_model.custom_backend_metrics_polling_interval(),
    );

65
    let http_service = match engine_config {
66
        EngineConfig::Dynamic(_) => {
67
            // This allows the /health endpoint to query store for active instances
68
            http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
69
            let http_service = http_service_builder.build()?;
70
71
72
73
74
75
76
77
78
79
80
81

            let router_config = engine_config.local_model().router_config();
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
            let namespace = engine_config.local_model().namespace().unwrap_or("");
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
82
                distributed_runtime.clone(),
83
84
85
86
87
88
89
90
91
                http_service.state().manager_clone(),
                router_config.router_mode,
                Some(router_config.kv_router_config),
                router_config.busy_threshold,
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
            )
            .await?;
92
            http_service
93
        }
94
        EngineConfig::StaticFull { engine, model, .. } => {
95
            let http_service = http_service_builder.build()?;
96
97
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
98
            let checksum = model.card().mdcsum();
99
100
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
101
102
103

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
104
                http_service.enable_model_endpoint(endpoint_type, true);
105
            }
106
            http_service
107
        }
108
109
        EngineConfig::StaticCore {
            engine: inner_engine,
110
            model,
111
            ..
112
        } => {
113
            let http_service = http_service_builder.build()?;
114
            let manager = http_service.model_manager();
115
            let checksum = model.card().mdcsum();
116

117
118
119
120
121
122
123
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
124
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
125

126
127
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
128
                NvCreateCompletionResponse,
129
            >(model.card(), inner_engine, tokenizer_hf)
130
            .await?;
131
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
132
133
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
134
                http_service.enable_model_endpoint(endpoint_type, true);
135
            }
136
            http_service
137
        }
138
    };
139
140
141
142
143
144
145
146
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168

    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Start custom backend metrics polling if configured
    let polling_task =
        if let (Some(namespace_component_endpoint), Some(polling_interval), Some(registry)) = (
            http_service
                .custom_backend_namespace_component_endpoint
                .as_ref(),
            http_service.custom_backend_metrics_polling_interval,
            http_service.custom_backend_registry.as_ref(),
        ) {
            tracing::info!(
                namespace_component_endpoint=%namespace_component_endpoint,
                polling_interval_secs=polling_interval,
                "Starting custom backend metrics polling task"
            );
            // Spawn the polling task and keep the JoinHandle alive so it can be aborted during
            // shutdown. While graceful shutdown is not strictly necessary for this non-critical
            // metrics polling, explicitly aborting it prevents the task from running during the
            // shutdown phase.
            Some(
                crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task(
169
                    distributed_runtime.clone(),
170
171
172
173
174
175
176
177
178
                    namespace_component_endpoint.clone(),
                    polling_interval,
                    registry.clone(),
                ),
            )
        } else {
            None
        };

179
180
181
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
182
183
184
185
186
187

    // Abort the polling task if it was started
    if let Some(task) = polling_task {
        task.abort();
    }

188
    distributed_runtime.shutdown(); // Cancel primary token
189
    Ok(())
190
}
191

192
/// Spawns a task that watches for new models in store,
193
/// and registers them with the ModelManager so that the HTTP service can use them.
194
#[allow(clippy::too_many_arguments)]
195
async fn run_watcher(
196
    runtime: DistributedRuntime,
197
    model_manager: Arc<ModelManager>,
198
    router_mode: RouterMode,
199
    kv_router_config: Option<KvRouterConfig>,
200
    busy_threshold: Option<f64>,
201
    target_namespace: Option<String>,
202
    http_service: Arc<HttpService>,
203
    metrics: Arc<crate::http::service::metrics::Metrics>,
204
) -> anyhow::Result<()> {
205
    let mut watch_obj = ModelWatcher::new(
206
        runtime.clone(),
207
208
209
210
211
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
212
    tracing::debug!("Waiting for remote model");
213
214
215
216
217
218
219
    let discovery = runtime.discovery();
    let discovery_stream = discovery
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;
220
221
222
223
224

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

225
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
226
    let _endpoint_enabler_task = tokio::spawn(async move {
227
228
229
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
230
231
232
        }
    });

233
    // Pass the discovery stream to the watcher
234
    let _watcher_task = tokio::spawn(async move {
235
236
237
        watch_obj
            .watch(discovery_stream, target_namespace.as_deref())
            .await;
238
    });
239

240
241
    Ok(())
}
242
243

/// Updates HTTP service endpoints based on available model types
244
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
245
246
247
248
249
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
250
        ModelUpdate::Added(card) => {
251
            // Handle all supported endpoint types, not just the first one
252
            for endpoint_type in card.model_type.as_endpoint_types() {
253
                service.enable_model_endpoint(endpoint_type, true);
254
            }
255
        }
256
        ModelUpdate::Removed(card) => {
257
            // Handle all supported endpoint types, not just the first one
258
            for endpoint_type in card.model_type.as_endpoint_types() {
259
                service.enable_model_endpoint(endpoint_type, false);
260
            }
261
        }
262
263
    }
}
264
265

/// Updates metrics for model type changes
266
fn update_model_metrics(
267
268
269
270
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
271
272
273
274
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
275
276
            }
        }
277
278
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
279
280
281
282
283
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}