http.rs 9.18 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{EngineConfig, EngineFactoryCallback, RouterConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    namespace::is_global_namespace,
13
14
15
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
16
17
    },
};
18
use dynamo_runtime::DistributedRuntime;
19
20

/// Build and run an HTTP service
21
22
23
24
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
    let http_service = match engine_config {
55
56
57
58
        EngineConfig::Dynamic {
            ref model,
            ref engine_factory,
        } => {
59
            // This allows the /health endpoint to query store for active instances
60
            http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
61
            let http_service = http_service_builder.build()?;
62

63
            let router_config = model.router_config();
64
            let migration_limit = model.migration_limit();
65
66
67
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
68
            let namespace = model.namespace().unwrap_or("");
69
70
71
72
73
74
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
75
                distributed_runtime.clone(),
76
                http_service.state().manager_clone(),
77
                router_config.clone(),
78
                migration_limit,
79
80
81
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
82
                engine_factory.clone(),
83
84
            )
            .await?;
85
            http_service
86
        }
87
        EngineConfig::InProcessText { engine, model, .. } => {
88
            let http_service = http_service_builder.build()?;
89
90
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
91
            let checksum = model.card().mdcsum();
92
93
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
94
95
96

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
97
                http_service.enable_model_endpoint(endpoint_type, true);
98
            }
99
            http_service
100
        }
101
        EngineConfig::InProcessTokens {
102
            engine: inner_engine,
103
            model,
104
            ..
105
        } => {
106
            let http_service = http_service_builder.build()?;
107
            let manager = http_service.model_manager();
108
            let checksum = model.card().mdcsum();
109

110
111
112
113
114
115
116
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
117
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
118

119
120
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
121
                NvCreateCompletionResponse,
122
            >(model.card(), inner_engine, tokenizer_hf)
123
            .await?;
124
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
125
126
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
127
                http_service.enable_model_endpoint(endpoint_type, true);
128
            }
129
            http_service
130
        }
131
    };
132
133
134
135
136
137
138
139
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
140

141
142
143
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
144

145
    distributed_runtime.shutdown(); // Cancel primary token
146
    Ok(())
147
}
148

149
/// Spawns a task that watches for new models in store,
150
/// and registers them with the ModelManager so that the HTTP service can use them.
151
#[allow(clippy::too_many_arguments)]
152
async fn run_watcher(
153
    runtime: DistributedRuntime,
154
    model_manager: Arc<ModelManager>,
155
    router_config: RouterConfig,
156
    migration_limit: u32,
157
    target_namespace: Option<String>,
158
    http_service: Arc<HttpService>,
159
    metrics: Arc<crate::http::service::metrics::Metrics>,
160
    engine_factory: Option<EngineFactoryCallback>,
161
) -> anyhow::Result<()> {
162
163
164
165
    let mut watch_obj = ModelWatcher::new(
        runtime.clone(),
        model_manager,
        router_config,
166
        migration_limit,
167
        engine_factory,
168
        metrics.clone(),
169
    );
170
    tracing::debug!("Waiting for remote model");
171
172
173
174
175
176
177
    let discovery = runtime.discovery();
    let discovery_stream = discovery
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;
178
179
180
181
182

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

183
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
184
    let _endpoint_enabler_task = tokio::spawn(async move {
185
186
187
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
188
189
190
        }
    });

191
    // Pass the discovery stream to the watcher
192
    let _watcher_task = tokio::spawn(async move {
193
194
195
        watch_obj
            .watch(discovery_stream, target_namespace.as_deref())
            .await;
196
    });
197

198
199
    Ok(())
}
200
201

/// Updates HTTP service endpoints based on available model types
202
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
203
204
205
206
207
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
208
        ModelUpdate::Added(card) => {
209
            // Handle all supported endpoint types, not just the first one
210
            for endpoint_type in card.model_type.as_endpoint_types() {
211
                service.enable_model_endpoint(endpoint_type, true);
212
            }
213
        }
214
        ModelUpdate::Removed(card) => {
215
            // Handle all supported endpoint types, not just the first one
216
            for endpoint_type in card.model_type.as_endpoint_types() {
217
                service.enable_model_endpoint(endpoint_type, false);
218
            }
219
        }
220
221
    }
}
222
223

/// Updates metrics for model type changes
224
fn update_model_metrics(
225
226
227
228
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
229
230
231
232
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
233
234
            }
        }
235
236
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
237
238
239
240
241
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}