http.rs 9.24 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    namespace::is_global_namespace,
13
14
15
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
16
17
    },
};
18
use dynamo_runtime::DistributedRuntime;
19
20

/// Build and run an HTTP service
21
22
23
24
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
    let http_service = match engine_config {
55
56
        EngineConfig::Dynamic {
            ref model,
57
            ref chat_engine_factory,
58
        } => {
59
60
61
            // Pass the discovery client so the /health endpoint can query active instances
            http_service_builder =
                http_service_builder.discovery(Some(distributed_runtime.discovery()));
62
            let http_service = http_service_builder.build()?;
63

64
            let router_config = model.router_config();
65
            let migration_limit = model.migration_limit();
66
67
68
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
69
            let namespace = model.namespace().unwrap_or("");
70
71
72
73
74
75
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
76
                distributed_runtime.clone(),
77
                http_service.state().manager_clone(),
78
                router_config.clone(),
79
                migration_limit,
80
81
82
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
83
                chat_engine_factory.clone(),
84
85
            )
            .await?;
86
            http_service
87
        }
88
        EngineConfig::InProcessText { engine, model, .. } => {
89
            let http_service = http_service_builder.build()?;
90
91
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
92
            let checksum = model.card().mdcsum();
93
94
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
95
96
97

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
98
                http_service.enable_model_endpoint(endpoint_type, true);
99
            }
100
            http_service
101
        }
102
        EngineConfig::InProcessTokens {
103
            engine: inner_engine,
104
            model,
105
            ..
106
        } => {
107
            let http_service = http_service_builder.build()?;
108
            let manager = http_service.model_manager();
109
            let checksum = model.card().mdcsum();
110

111
112
113
114
115
116
117
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
118
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
119

120
121
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
122
                NvCreateCompletionResponse,
123
            >(model.card(), inner_engine, tokenizer_hf)
124
            .await?;
125
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
126
127
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
128
                http_service.enable_model_endpoint(endpoint_type, true);
129
            }
130
            http_service
131
        }
132
    };
133
134
135
136
137
138
139
140
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
141

142
143
144
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
145

146
    distributed_runtime.shutdown(); // Cancel primary token
147
    Ok(())
148
}
149

150
/// Spawns a task that watches for new models in store,
151
/// and registers them with the ModelManager so that the HTTP service can use them.
152
#[allow(clippy::too_many_arguments)]
153
async fn run_watcher(
154
    runtime: DistributedRuntime,
155
    model_manager: Arc<ModelManager>,
156
    router_config: RouterConfig,
157
    migration_limit: u32,
158
    target_namespace: Option<String>,
159
    http_service: Arc<HttpService>,
160
    metrics: Arc<crate::http::service::metrics::Metrics>,
161
    chat_engine_factory: Option<ChatEngineFactoryCallback>,
162
) -> anyhow::Result<()> {
163
164
165
166
    let mut watch_obj = ModelWatcher::new(
        runtime.clone(),
        model_manager,
        router_config,
167
        migration_limit,
168
        chat_engine_factory,
169
        metrics.clone(),
170
    );
171
    tracing::debug!("Waiting for remote model");
172
173
174
175
176
177
178
    let discovery = runtime.discovery();
    let discovery_stream = discovery
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;
179
180
181
182
183

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

184
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
185
    let _endpoint_enabler_task = tokio::spawn(async move {
186
187
188
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
189
190
191
        }
    });

192
    // Pass the discovery stream to the watcher
193
    let _watcher_task = tokio::spawn(async move {
194
195
196
        watch_obj
            .watch(discovery_stream, target_namespace.as_deref())
            .await;
197
    });
198

199
200
    Ok(())
}
201
202

/// Updates HTTP service endpoints based on available model types
203
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
204
205
206
207
208
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
209
        ModelUpdate::Added(card) => {
210
            // Handle all supported endpoint types, not just the first one
211
            for endpoint_type in card.model_type.as_endpoint_types() {
212
                service.enable_model_endpoint(endpoint_type, true);
213
            }
214
        }
215
        ModelUpdate::Removed(card) => {
216
            // Handle all supported endpoint types, not just the first one
217
            for endpoint_type in card.model_type.as_endpoint_types() {
218
                service.enable_model_endpoint(endpoint_type, false);
219
            }
220
        }
221
222
    }
}
223
224

/// Updates metrics for model type changes
225
fn update_model_metrics(
226
227
228
229
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
230
231
232
233
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
234
235
            }
        }
236
237
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
238
239
240
241
242
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}