// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    namespace::is_global_namespace,
13
14
15
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
16
17
    },
};
18
use dynamo_runtime::DistributedRuntime;
19
20

/// Build and run an HTTP service
21
22
23
24
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
51
52
    http_service_builder =
        http_service_builder.cancel_token(Some(distributed_runtime.primary_token()));
Graham King's avatar
Graham King committed
53
54
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
55

56
    let http_service = match engine_config {
57
58
        EngineConfig::Dynamic {
            ref model,
59
            ref chat_engine_factory,
60
        } => {
61
62
63
            // Pass the discovery client so the /health endpoint can query active instances
            http_service_builder =
                http_service_builder.discovery(Some(distributed_runtime.discovery()));
64
            let http_service = http_service_builder.build()?;
65

66
            let router_config = model.router_config();
67
            let migration_limit = model.migration_limit();
68
69
70
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
71
            let namespace = model.namespace().unwrap_or("");
72
73
74
75
76
77
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
78
                distributed_runtime.clone(),
79
                http_service.state().manager_clone(),
80
                router_config.clone(),
81
                migration_limit,
82
83
84
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
85
                chat_engine_factory.clone(),
86
87
            )
            .await?;
88
            http_service
89
        }
90
        EngineConfig::InProcessText { engine, model, .. } => {
91
            let http_service = http_service_builder.build()?;
92
93
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
94
            let checksum = model.card().mdcsum();
95
96
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
97
98
99

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
100
                http_service.enable_model_endpoint(endpoint_type, true);
101
            }
102
            http_service
103
        }
104
        EngineConfig::InProcessTokens {
105
            engine: inner_engine,
106
            model,
107
            ..
108
        } => {
109
            let http_service = http_service_builder.build()?;
110
            let manager = http_service.model_manager();
111
            let checksum = model.card().mdcsum();
112

113
114
115
116
117
118
119
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
120
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
121

122
123
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
124
                NvCreateCompletionResponse,
125
            >(model.card(), inner_engine, tokenizer_hf)
126
            .await?;
127
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
128
129
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
130
                http_service.enable_model_endpoint(endpoint_type, true);
131
            }
132
            http_service
133
        }
134
    };
135
136
137
138
139
140
141
142
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
143

144
145
146
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
147

148
    distributed_runtime.shutdown(); // Cancel primary token
149
    Ok(())
150
}
151

152
/// Spawns a task that watches for new models in store,
153
/// and registers them with the ModelManager so that the HTTP service can use them.
154
#[allow(clippy::too_many_arguments)]
155
async fn run_watcher(
156
    runtime: DistributedRuntime,
157
    model_manager: Arc<ModelManager>,
158
    router_config: RouterConfig,
159
    migration_limit: u32,
160
    target_namespace: Option<String>,
161
    http_service: Arc<HttpService>,
162
    metrics: Arc<crate::http::service::metrics::Metrics>,
163
    chat_engine_factory: Option<ChatEngineFactoryCallback>,
164
) -> anyhow::Result<()> {
165
166
167
168
    let mut watch_obj = ModelWatcher::new(
        runtime.clone(),
        model_manager,
        router_config,
169
        migration_limit,
170
        chat_engine_factory,
171
        metrics.clone(),
172
    );
173
    tracing::debug!("Waiting for remote model");
174
175
176
177
178
179
180
    let discovery = runtime.discovery();
    let discovery_stream = discovery
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;
181
182
183
184
185

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

186
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
187
    let _endpoint_enabler_task = tokio::spawn(async move {
188
189
190
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
191
192
193
        }
    });

194
    // Pass the discovery stream to the watcher
195
    let _watcher_task = tokio::spawn(async move {
196
197
198
        watch_obj
            .watch(discovery_stream, target_namespace.as_deref())
            .await;
199
    });
200

201
202
    Ok(())
}
203
204

/// Updates HTTP service endpoints based on available model types
205
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
206
207
208
209
210
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
211
        ModelUpdate::Added(card) => {
212
            // Handle all supported endpoint types, not just the first one
213
            for endpoint_type in card.model_type.as_endpoint_types() {
214
                service.enable_model_endpoint(endpoint_type, true);
215
            }
216
        }
217
        ModelUpdate::Removed(card) => {
218
            // Handle all supported endpoint types, not just the first one
219
            for endpoint_type in card.model_type.as_endpoint_types() {
220
                service.enable_model_endpoint(endpoint_type, false);
221
            }
222
        }
223
224
    }
}
225
226

/// Updates metrics for model type changes
227
fn update_model_metrics(
228
229
230
231
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
232
233
234
235
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
236
237
            }
        }
238
239
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
240
241
242
243
244
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}