http.rs 11 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    namespace::is_global_namespace,
14
15
16
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
17
18
    },
};
19
use dynamo_runtime::transports::etcd;
20
use dynamo_runtime::{DistributedRuntime, Runtime};
21
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
22
23

/// Build and run an HTTP service
24
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
    let http_service = match engine_config {
55
        EngineConfig::Dynamic(_) => {
56
57
58
59
60
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            let etcd_client = distributed_runtime.etcd_client();
            // This allows the /health endpoint to query etcd for active instances
            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
            let http_service = http_service_builder.build()?;
61
62
            match etcd_client {
                Some(ref etcd_client) => {
63
                    let router_config = engine_config.local_model().router_config();
64
                    // Listen for models registering themselves in etcd, add them to HTTP service
65
66
67
68
69
70
71
72
                    // Check if we should filter by namespace (based on the local model's namespace)
                    // Get namespace from the model, fallback to endpoint_id namespace if not set
                    let namespace = engine_config.local_model().namespace().unwrap_or("");
                    let target_namespace = if is_global_namespace(namespace) {
                        None
                    } else {
                        Some(namespace.to_string())
                    };
73
                    run_watcher(
74
                        distributed_runtime,
75
                        http_service.state().manager_clone(),
76
                        etcd_client.clone(),
77
                        MODEL_ROOT_PATH,
78
                        router_config.router_mode,
79
                        Some(router_config.kv_router_config),
80
                        router_config.busy_threshold,
81
                        target_namespace,
82
                        Arc::new(http_service.clone()),
83
84
                    )
                    .await?;
85
86
87
88
89
                }
                None => {
                    // Static endpoints don't need discovery
                }
            }
90
            http_service
91
        }
92
93
94
95
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
            let router_mode = local_model.router_config().router_mode;

96
            let dst_config = DistributedConfig::from_settings(true); // true means static
97
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
98
            let http_service = http_service_builder.build()?;
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

122
            let tokenizer_hf = card.tokenizer_hf()?;
123
124
125
            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
126
127
128
129
130
131
132
133
            >(
                card,
                &client,
                router_mode,
                None,
                kv_chooser.clone(),
                tokenizer_hf.clone(),
            )
134
135
136
            .await?;
            manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;

137
138
139
140
141
142
            let completions_engine =
                entrypoint::build_routed_pipeline::<
                    NvCreateCompletionRequest,
                    NvCreateCompletionResponse,
                >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
                .await?;
143
            manager.add_completions_model(local_model.display_name(), completions_engine)?;
144

145
            for endpoint_type in EndpointType::all() {
146
                http_service.enable_model_endpoint(endpoint_type, true);
147
148
            }

149
            http_service
150
151
        }
        EngineConfig::StaticFull { engine, model, .. } => {
152
            let http_service = http_service_builder.build()?;
153
154
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
155
156
            manager.add_completions_model(model.service_name(), engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), engine)?;
157
158
159

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
160
                http_service.enable_model_endpoint(endpoint_type, true);
161
            }
162
            http_service
163
        }
164
165
        EngineConfig::StaticCore {
            engine: inner_engine,
166
            model,
167
            ..
168
        } => {
169
            let http_service = http_service_builder.build()?;
170
171
            let manager = http_service.model_manager();

172
173
174
175
176
177
178
            let tokenizer_hf = model.card().tokenizer_hf()?;
            let chat_pipeline =
                common::build_pipeline::<
                    NvCreateChatCompletionRequest,
                    NvCreateChatCompletionStreamResponse,
                >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
                .await?;
179
            manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
180

181
182
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
183
                NvCreateCompletionResponse,
184
            >(model.card(), inner_engine, tokenizer_hf)
185
            .await?;
186
            manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
187
188
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
189
                http_service.enable_model_endpoint(endpoint_type, true);
190
            }
191
            http_service
192
        }
193
    };
194
195
196
197
198
199
200
201
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
202
203
204
    http_service.run(runtime.primary_token()).await?;
    runtime.shutdown(); // Cancel primary token
    Ok(())
205
}
206
207
208

/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
209
#[allow(clippy::too_many_arguments)]
210
async fn run_watcher(
211
    runtime: DistributedRuntime,
212
    model_manager: Arc<ModelManager>,
213
214
    etcd_client: etcd::Client,
    network_prefix: &str,
215
    router_mode: RouterMode,
216
    kv_router_config: Option<KvRouterConfig>,
217
    busy_threshold: Option<f64>,
218
    target_namespace: Option<String>,
219
    http_service: Arc<HttpService>,
220
) -> anyhow::Result<()> {
221
222
223
224
225
226
227
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
228
229
230
    tracing::info!("Watching for remote model at {network_prefix}");
    let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
    let (_prefix, _watcher, receiver) = models_watcher.dissolve();
231
232
233
234
235
236
237
238
239
240

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);

    watch_obj.set_notify_on_model_update(tx);

    // Spawn a task to watch for model type changes and update HTTP service endpoints
    let _endpoint_enabler_task = tokio::spawn(async move {
        while let Some(model_type) = rx.recv().await {
            tracing::debug!("Received model type update: {:?}", model_type);
241
            update_http_endpoints(http_service.clone(), model_type);
242
243
244
245
        }
    });

    // Pass the sender to the watcher
246
    let _watcher_task = tokio::spawn(async move {
247
        watch_obj.watch(receiver, target_namespace.as_deref()).await;
248
    });
249

250
251
    Ok(())
}
252
253

/// Updates HTTP service endpoints based on available model types
254
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
255
256
257
258
259
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
260
261
262
263
        ModelUpdate::Added(model_type) => {
            // Handle all supported endpoint types, not just the first one
            for endpoint_type in model_type.as_endpoint_types() {
                service.enable_model_endpoint(endpoint_type, true);
264
            }
265
266
267
268
269
        }
        ModelUpdate::Removed(model_type) => {
            // Handle all supported endpoint types, not just the first one
            for endpoint_type in model_type.as_endpoint_types() {
                service.enable_model_endpoint(endpoint_type, false);
270
            }
271
        }
272
273
    }
}