http.rs 12.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    model_card,
14
    namespace::is_global_namespace,
15
16
17
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
18
19
    },
};
20
use dynamo_runtime::transports::etcd;
21
use dynamo_runtime::{DistributedRuntime, Runtime};
22
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
23
24

/// Build and run an HTTP service
25
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
49
50
51
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
52
53
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
54

55
    let http_service = match engine_config {
56
        EngineConfig::Dynamic(_) => {
57
58
59
60
61
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            let etcd_client = distributed_runtime.etcd_client();
            // This allows the /health endpoint to query etcd for active instances
            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
            let http_service = http_service_builder.build()?;
62
63
            match etcd_client {
                Some(ref etcd_client) => {
64
                    let router_config = engine_config.local_model().router_config();
65
                    // Listen for models registering themselves in etcd, add them to HTTP service
66
67
68
69
70
71
72
73
                    // Check if we should filter by namespace (based on the local model's namespace)
                    // Get namespace from the model, fallback to endpoint_id namespace if not set
                    let namespace = engine_config.local_model().namespace().unwrap_or("");
                    let target_namespace = if is_global_namespace(namespace) {
                        None
                    } else {
                        Some(namespace.to_string())
                    };
74
                    run_watcher(
75
                        distributed_runtime,
76
                        http_service.state().manager_clone(),
77
                        etcd_client.clone(),
78
                        model_card::ROOT_PATH,
79
                        router_config.router_mode,
80
                        Some(router_config.kv_router_config),
81
                        router_config.busy_threshold,
82
                        target_namespace,
83
                        Arc::new(http_service.clone()),
84
                        http_service.state().metrics_clone(),
85
86
                    )
                    .await?;
87
88
89
90
91
                }
                None => {
                    // Static endpoints don't need discovery
                }
            }
92
            http_service
93
        }
94
95
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
96
97
            let checksum = card.mdcsum();

98
99
            let router_mode = local_model.router_config().router_mode;

100
            let dst_config = DistributedConfig::from_settings(true); // true means static
101
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
102
            let http_service = http_service_builder.build()?;
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

126
            let tokenizer_hf = card.tokenizer_hf()?;
127
128
129
            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
130
131
132
133
134
135
136
137
            >(
                card,
                &client,
                router_mode,
                None,
                kv_chooser.clone(),
                tokenizer_hf.clone(),
            )
138
            .await?;
139
140
141
142
143
            manager.add_chat_completions_model(
                local_model.display_name(),
                checksum,
                chat_engine,
            )?;
144

145
146
147
148
149
150
            let completions_engine =
                entrypoint::build_routed_pipeline::<
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
                .await?;
151
152
153
154
155
            manager.add_completions_model(
                local_model.display_name(),
                checksum,
                completions_engine,
            )?;
156

157
            for endpoint_type in EndpointType::all() {
158
                http_service.enable_model_endpoint(endpoint_type, true);
159
160
            }

161
            http_service
162
163
        }
        EngineConfig::StaticFull { engine, model, .. } => {
164
            let http_service = http_service_builder.build()?;
165
166
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
167
168
169
            let checksum = model.card().mdcsum();
            manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
170
171
172

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
173
                http_service.enable_model_endpoint(endpoint_type, true);
174
            }
175
            http_service
176
        }
177
178
        EngineConfig::StaticCore {
            engine: inner_engine,
179
            model,
180
            ..
181
        } => {
182
            let http_service = http_service_builder.build()?;
183
            let manager = http_service.model_manager();
184
            let checksum = model.card().mdcsum();
185

186
187
188
189
190
191
192
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
193
            manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
194

195
196
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
197
                NvCreateCompletionResponse,
198
            >(model.card(), inner_engine, tokenizer_hf)
199
            .await?;
200
            manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
201
202
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
203
                http_service.enable_model_endpoint(endpoint_type, true);
204
            }
205
            http_service
206
        }
207
    };
208
209
210
211
212
213
214
215
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
216
217
218
    http_service.run(runtime.primary_token()).await?;
    runtime.shutdown(); // Cancel primary token
    Ok(())
219
}
220
221
222

/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
223
#[allow(clippy::too_many_arguments)]
224
async fn run_watcher(
225
    runtime: DistributedRuntime,
226
    model_manager: Arc<ModelManager>,
227
228
    etcd_client: etcd::Client,
    network_prefix: &str,
229
    router_mode: RouterMode,
230
    kv_router_config: Option<KvRouterConfig>,
231
    busy_threshold: Option<f64>,
232
    target_namespace: Option<String>,
233
    http_service: Arc<HttpService>,
234
    metrics: Arc<crate::http::service::metrics::Metrics>,
235
) -> anyhow::Result<()> {
236
237
238
239
240
241
242
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
243
244
245
    tracing::info!("Watching for remote model at {network_prefix}");
    let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
    let (_prefix, _watcher, receiver) = models_watcher.dissolve();
246
247
248
249
250
251

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);

    watch_obj.set_notify_on_model_update(tx);

252
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
253
    let _endpoint_enabler_task = tokio::spawn(async move {
254
255
256
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
257
258
259
260
        }
    });

    // Pass the sender to the watcher
261
    let _watcher_task = tokio::spawn(async move {
262
        watch_obj.watch(receiver, target_namespace.as_deref()).await;
263
    });
264

265
266
    Ok(())
}
267
268

/// Updates HTTP service endpoints based on available model types
269
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
270
271
272
273
274
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
275
        ModelUpdate::Added(card) => {
276
            // Handle all supported endpoint types, not just the first one
277
            for endpoint_type in card.model_type.as_endpoint_types() {
278
                service.enable_model_endpoint(endpoint_type, true);
279
            }
280
        }
281
        ModelUpdate::Removed(card) => {
282
            // Handle all supported endpoint types, not just the first one
283
            for endpoint_type in card.model_type.as_endpoint_types() {
284
                service.enable_model_endpoint(endpoint_type, false);
285
            }
286
        }
287
288
    }
}
289
290

/// Updates metrics for model type changes
291
fn update_model_metrics(
292
293
294
295
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
296
297
298
299
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
300
301
            }
        }
302
303
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
304
305
306
307
308
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}