http.rs 10.9 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    model_type::ModelType,
14
    namespace::is_global_namespace,
15
16
17
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
18
19
    },
};
20
use dynamo_runtime::transports::etcd;
21
use dynamo_runtime::{DistributedRuntime, Runtime};
22
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
23
24

/// Build and run an HTTP service
25
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
49
50
51
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
52
53
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
54

55
    let http_service = match engine_config {
56
        EngineConfig::Dynamic(_) => {
57
58
59
60
61
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            let etcd_client = distributed_runtime.etcd_client();
            // This allows the /health endpoint to query etcd for active instances
            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
            let http_service = http_service_builder.build()?;
62
63
            match etcd_client {
                Some(ref etcd_client) => {
64
                    let router_config = engine_config.local_model().router_config();
65
                    // Listen for models registering themselves in etcd, add them to HTTP service
66
67
68
69
70
71
72
73
                    // Check if we should filter by namespace (based on the local model's namespace)
                    // Get namespace from the model, fallback to endpoint_id namespace if not set
                    let namespace = engine_config.local_model().namespace().unwrap_or("");
                    let target_namespace = if is_global_namespace(namespace) {
                        None
                    } else {
                        Some(namespace.to_string())
                    };
74
                    run_watcher(
75
                        distributed_runtime,
76
                        http_service.state().manager_clone(),
77
                        etcd_client.clone(),
78
                        MODEL_ROOT_PATH,
79
                        router_config.router_mode,
80
                        Some(router_config.kv_router_config),
81
                        router_config.busy_threshold,
82
                        target_namespace,
83
                        Arc::new(http_service.clone()),
84
85
                    )
                    .await?;
86
87
88
89
90
                }
                None => {
                    // Static endpoints don't need discovery
                }
            }
91
            http_service
92
        }
93
94
95
96
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
            let router_mode = local_model.router_config().router_mode;

97
            let dst_config = DistributedConfig::from_settings(true); // true means static
98
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
99
            let http_service = http_service_builder.build()?;
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
126
            >(card, &client, router_mode, None, kv_chooser.clone())
127
128
129
130
131
132
            .await?;
            manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;

            let completions_engine = entrypoint::build_routed_pipeline::<
                NvCreateCompletionRequest,
                NvCreateCompletionResponse,
133
            >(card, &client, router_mode, None, kv_chooser)
134
135
            .await?;
            manager.add_completions_model(local_model.display_name(), completions_engine)?;
136

137
            for endpoint_type in EndpointType::all() {
138
                http_service.enable_model_endpoint(endpoint_type, true);
139
140
            }

141
            http_service
142
143
        }
        EngineConfig::StaticFull { engine, model, .. } => {
144
            let http_service = http_service_builder.build()?;
145
146
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
147
148
            manager.add_completions_model(model.service_name(), engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), engine)?;
149
150
151

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
152
                http_service.enable_model_endpoint(endpoint_type, true);
153
            }
154
            http_service
155
        }
156
157
        EngineConfig::StaticCore {
            engine: inner_engine,
158
            model,
159
            ..
160
        } => {
161
            let http_service = http_service_builder.build()?;
162
163
164
165
166
            let manager = http_service.model_manager();

            let chat_pipeline = common::build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
167
            >(model.card(), inner_engine.clone())
168
            .await?;
169
            manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
170

171
172
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
173
                NvCreateCompletionResponse,
174
            >(model.card(), inner_engine)
175
            .await?;
176
            manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
177
178
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
179
                http_service.enable_model_endpoint(endpoint_type, true);
180
            }
181
            http_service
182
        }
183
    };
184
185
186
187
188
189
190
191
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
192
193
194
    http_service.run(runtime.primary_token()).await?;
    runtime.shutdown(); // Cancel primary token
    Ok(())
195
}
196
197
198

/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
199
#[allow(clippy::too_many_arguments)]
200
async fn run_watcher(
201
    runtime: DistributedRuntime,
202
    model_manager: Arc<ModelManager>,
203
204
    etcd_client: etcd::Client,
    network_prefix: &str,
205
    router_mode: RouterMode,
206
    kv_router_config: Option<KvRouterConfig>,
207
    busy_threshold: Option<f64>,
208
    target_namespace: Option<String>,
209
    http_service: Arc<HttpService>,
210
) -> anyhow::Result<()> {
211
212
213
214
215
216
217
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
218
219
220
    tracing::info!("Watching for remote model at {network_prefix}");
    let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
    let (_prefix, _watcher, receiver) = models_watcher.dissolve();
221
222
223
224
225
226
227
228
229
230

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);

    watch_obj.set_notify_on_model_update(tx);

    // Spawn a task to watch for model type changes and update HTTP service endpoints
    let _endpoint_enabler_task = tokio::spawn(async move {
        while let Some(model_type) = rx.recv().await {
            tracing::debug!("Received model type update: {:?}", model_type);
231
            update_http_endpoints(http_service.clone(), model_type);
232
233
234
235
        }
    });

    // Pass the sender to the watcher
236
    let _watcher_task = tokio::spawn(async move {
237
        watch_obj.watch(receiver, target_namespace.as_deref()).await;
238
    });
239

240
241
    Ok(())
}
242
243

/// Updates HTTP service endpoints based on available model types
244
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
245
246
247
248
249
250
251
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
        ModelUpdate::Added(model_type) => match model_type {
            ModelType::Backend => {
252
253
                service.enable_model_endpoint(EndpointType::Chat, true);
                service.enable_model_endpoint(EndpointType::Completion, true);
254
255
            }
            _ => {
256
                service.enable_model_endpoint(model_type.as_endpoint_type(), true);
257
258
259
260
            }
        },
        ModelUpdate::Removed(model_type) => match model_type {
            ModelType::Backend => {
261
262
                service.enable_model_endpoint(EndpointType::Chat, false);
                service.enable_model_endpoint(EndpointType::Completion, false);
263
264
            }
            _ => {
265
                service.enable_model_endpoint(model_type.as_endpoint_type(), false);
266
267
268
269
            }
        },
    }
}