"deploy/vscode:/vscode.git/clone" did not exist on "9572355fc1499fe6610083235637d4709527de7f"
http.rs 10.7 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{EngineConfig, RouterConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    namespace::is_global_namespace,
13
14
15
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
16
17
    },
};
18
use dynamo_runtime::DistributedRuntime;
19
20

/// Build and run an HTTP service
21
22
23
24
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
55
56
57
58
59
60
61
62
    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Pass the custom backend metrics endpoint as-is (already in namespace.component.endpoint format)
    http_service_builder = http_service_builder.with_custom_backend_config(
        local_model
            .custom_backend_metrics_endpoint()
            .map(|s| s.to_string()),
        local_model.custom_backend_metrics_polling_interval(),
    );

63
    let http_service = match engine_config {
64
        EngineConfig::Dynamic(_) => {
65
            // This allows the /health endpoint to query store for active instances
66
            http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
67
            let http_service = http_service_builder.build()?;
68
69
70
71
72
73
74
75
76
77
78
79

            let router_config = engine_config.local_model().router_config();
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
            let namespace = engine_config.local_model().namespace().unwrap_or("");
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
80
                distributed_runtime.clone(),
81
                http_service.state().manager_clone(),
82
                router_config.clone(),
83
84
85
86
87
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
            )
            .await?;
88
            http_service
89
        }
90
        EngineConfig::StaticFull { engine, model, .. } => {
91
            let http_service = http_service_builder.build()?;
92
93
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
94
            let checksum = model.card().mdcsum();
95
96
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
97
98
99

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
100
                http_service.enable_model_endpoint(endpoint_type, true);
101
            }
102
            http_service
103
        }
104
105
        EngineConfig::StaticCore {
            engine: inner_engine,
106
            model,
107
            ..
108
        } => {
109
            let http_service = http_service_builder.build()?;
110
            let manager = http_service.model_manager();
111
            let checksum = model.card().mdcsum();
112

113
114
115
116
117
118
119
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
120
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
121

122
123
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
124
                NvCreateCompletionResponse,
125
            >(model.card(), inner_engine, tokenizer_hf)
126
            .await?;
127
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
128
129
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
130
                http_service.enable_model_endpoint(endpoint_type, true);
131
            }
132
            http_service
133
        }
134
    };
135
136
137
138
139
140
141
142
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164

    // DEPRECATED: To be removed after custom backends migrate to Dynamo backend.
    // Start custom backend metrics polling if configured
    let polling_task =
        if let (Some(namespace_component_endpoint), Some(polling_interval), Some(registry)) = (
            http_service
                .custom_backend_namespace_component_endpoint
                .as_ref(),
            http_service.custom_backend_metrics_polling_interval,
            http_service.custom_backend_registry.as_ref(),
        ) {
            tracing::info!(
                namespace_component_endpoint=%namespace_component_endpoint,
                polling_interval_secs=polling_interval,
                "Starting custom backend metrics polling task"
            );
            // Spawn the polling task and keep the JoinHandle alive so it can be aborted during
            // shutdown. While graceful shutdown is not strictly necessary for this non-critical
            // metrics polling, explicitly aborting it prevents the task from running during the
            // shutdown phase.
            Some(
                crate::http::service::custom_backend_metrics::spawn_custom_backend_polling_task(
165
                    distributed_runtime.clone(),
166
167
168
169
170
171
172
173
174
                    namespace_component_endpoint.clone(),
                    polling_interval,
                    registry.clone(),
                ),
            )
        } else {
            None
        };

175
176
177
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
178
179
180
181
182
183

    // Abort the polling task if it was started
    if let Some(task) = polling_task {
        task.abort();
    }

184
    distributed_runtime.shutdown(); // Cancel primary token
185
    Ok(())
186
}
187

188
/// Spawns a task that watches for new models in store,
189
190
/// and registers them with the ModelManager so that the HTTP service can use them.
async fn run_watcher(
191
    runtime: DistributedRuntime,
192
    model_manager: Arc<ModelManager>,
193
    router_config: RouterConfig,
194
    target_namespace: Option<String>,
195
    http_service: Arc<HttpService>,
196
    metrics: Arc<crate::http::service::metrics::Metrics>,
197
) -> anyhow::Result<()> {
198
    let mut watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config);
199
    tracing::debug!("Waiting for remote model");
200
201
202
203
204
205
206
    let discovery = runtime.discovery();
    let discovery_stream = discovery
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;
207
208
209
210
211

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

212
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
213
    let _endpoint_enabler_task = tokio::spawn(async move {
214
215
216
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
217
218
219
        }
    });

220
    // Pass the discovery stream to the watcher
221
    let _watcher_task = tokio::spawn(async move {
222
223
224
        watch_obj
            .watch(discovery_stream, target_namespace.as_deref())
            .await;
225
    });
226

227
228
    Ok(())
}
229
230

/// Updates HTTP service endpoints based on available model types
231
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
232
233
234
235
236
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
237
        ModelUpdate::Added(card) => {
238
            // Handle all supported endpoint types, not just the first one
239
            for endpoint_type in card.model_type.as_endpoint_types() {
240
                service.enable_model_endpoint(endpoint_type, true);
241
            }
242
        }
243
        ModelUpdate::Removed(card) => {
244
            // Handle all supported endpoint types, not just the first one
245
            for endpoint_type in card.model_type.as_endpoint_types() {
246
                service.enable_model_endpoint(endpoint_type, false);
247
            }
248
        }
249
250
    }
}
251
252

/// Updates metrics for model type changes
253
fn update_model_metrics(
254
255
256
257
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
258
259
260
261
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
262
263
            }
        }
264
265
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
266
267
268
269
270
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}