// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{EngineConfig, EngineFactoryCallback, RouterConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    namespace::is_global_namespace,
13
14
15
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
16
17
    },
};
18
use dynamo_runtime::DistributedRuntime;
19
20

/// Build and run an HTTP service
21
22
23
24
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
    let http_service = match engine_config {
55
56
57
58
        EngineConfig::Dynamic {
            ref model,
            ref engine_factory,
        } => {
59
            // This allows the /health endpoint to query store for active instances
60
            http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
61
            let http_service = http_service_builder.build()?;
62

63
            let router_config = model.router_config();
64
65
66
            // Listen for models registering themselves, add them to HTTP service
            // Check if we should filter by namespace (based on the local model's namespace)
            // Get namespace from the model, fallback to endpoint_id namespace if not set
67
            let namespace = model.namespace().unwrap_or("");
68
69
70
71
72
73
            let target_namespace = if is_global_namespace(namespace) {
                None
            } else {
                Some(namespace.to_string())
            };
            run_watcher(
74
                distributed_runtime.clone(),
75
                http_service.state().manager_clone(),
76
                router_config.clone(),
77
78
79
                target_namespace,
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
80
                engine_factory.clone(),
81
82
            )
            .await?;
83
            http_service
84
        }
85
        EngineConfig::InProcessText { engine, model, .. } => {
86
            let http_service = http_service_builder.build()?;
87
88
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
89
            let checksum = model.card().mdcsum();
90
91
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
92
93
94

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
95
                http_service.enable_model_endpoint(endpoint_type, true);
96
            }
97
            http_service
98
        }
99
        EngineConfig::InProcessTokens {
100
            engine: inner_engine,
101
            model,
102
            ..
103
        } => {
104
            let http_service = http_service_builder.build()?;
105
            let manager = http_service.model_manager();
106
            let checksum = model.card().mdcsum();
107

108
109
110
111
112
113
114
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
115
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
116

117
118
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
119
                NvCreateCompletionResponse,
120
            >(model.card(), inner_engine, tokenizer_hf)
121
            .await?;
122
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
123
124
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
125
                http_service.enable_model_endpoint(endpoint_type, true);
126
            }
127
            http_service
128
        }
129
    };
130
131
132
133
134
135
136
137
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
138

139
140
141
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
142

143
    distributed_runtime.shutdown(); // Cancel primary token
144
    Ok(())
145
}
146

147
/// Spawns a task that watches for new models in store,
148
149
/// and registers them with the ModelManager so that the HTTP service can use them.
async fn run_watcher(
150
    runtime: DistributedRuntime,
151
    model_manager: Arc<ModelManager>,
152
    router_config: RouterConfig,
153
    target_namespace: Option<String>,
154
    http_service: Arc<HttpService>,
155
    metrics: Arc<crate::http::service::metrics::Metrics>,
156
    engine_factory: Option<EngineFactoryCallback>,
157
) -> anyhow::Result<()> {
158
159
160
161
162
    let mut watch_obj = ModelWatcher::new(
        runtime.clone(),
        model_manager,
        router_config,
        engine_factory,
163
        metrics.clone(),
164
    );
165
    tracing::debug!("Waiting for remote model");
166
167
168
169
170
171
172
    let discovery = runtime.discovery();
    let discovery_stream = discovery
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;
173
174
175
176
177

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

178
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
179
    let _endpoint_enabler_task = tokio::spawn(async move {
180
181
182
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
183
184
185
        }
    });

186
    // Pass the discovery stream to the watcher
187
    let _watcher_task = tokio::spawn(async move {
188
189
190
        watch_obj
            .watch(discovery_stream, target_namespace.as_deref())
            .await;
191
    });
192

193
194
    Ok(())
}
195
196

/// Updates HTTP service endpoints based on available model types
197
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
198
199
200
201
202
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
203
        ModelUpdate::Added(card) => {
204
            // Handle all supported endpoint types, not just the first one
205
            for endpoint_type in card.model_type.as_endpoint_types() {
206
                service.enable_model_endpoint(endpoint_type, true);
207
            }
208
        }
209
        ModelUpdate::Removed(card) => {
210
            // Handle all supported endpoint types, not just the first one
211
            for endpoint_type in card.model_type.as_endpoint_types() {
212
                service.enable_model_endpoint(endpoint_type, false);
213
            }
214
        }
215
216
    }
}
217
218

/// Updates metrics for model type changes
219
fn update_model_metrics(
220
221
222
223
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
224
225
226
227
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
228
229
            }
        }
230
231
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
232
233
234
235
236
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}