http.rs 10.7 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
8
    discovery::{ModelManager, ModelUpdate, ModelWatcher, MODEL_ROOT_PATH},
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, input::common, EngineConfig},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    model_type::ModelType,
14
15
16
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
17
18
    },
};
19
use dynamo_runtime::transports::etcd;
20
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
21
use dynamo_runtime::{DistributedRuntime, Runtime};
22
23

/// Build and run an HTTP service
24
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
    let http_service = match engine_config {
55
        EngineConfig::Dynamic(_) => {
56
57
58
59
60
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            let etcd_client = distributed_runtime.etcd_client();
            // This allows the /health endpoint to query etcd for active instances
            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
            let http_service = http_service_builder.build()?;
61
62
            match etcd_client {
                Some(ref etcd_client) => {
63
                    let router_config = engine_config.local_model().router_config();
64
                    // Listen for models registering themselves in etcd, add them to HTTP service
65
                    run_watcher(
66
                        distributed_runtime,
67
                        http_service.state().manager_clone(),
68
                        etcd_client.clone(),
69
                        MODEL_ROOT_PATH,
70
                        router_config.router_mode,
71
                        Some(router_config.kv_router_config),
72
                        router_config.busy_threshold,
73
                        Arc::new(http_service.clone()),
74
75
                    )
                    .await?;
76
77
78
79
80
                }
                None => {
                    // Static endpoints don't need discovery
                }
            }
81
            http_service
82
        }
83
84
85
86
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
            let router_mode = local_model.router_config().router_mode;

87
            let dst_config = DistributedConfig::from_settings(true); // true means static
88
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
89
            let http_service = http_service_builder.build()?;
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
116
            >(card, &client, router_mode, None, kv_chooser.clone())
117
118
119
120
121
122
            .await?;
            manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;

            let completions_engine = entrypoint::build_routed_pipeline::<
                NvCreateCompletionRequest,
                NvCreateCompletionResponse,
123
            >(card, &client, router_mode, None, kv_chooser)
124
125
            .await?;
            manager.add_completions_model(local_model.display_name(), completions_engine)?;
126

127
128
129
130
131
132
            for endpoint_type in EndpointType::all() {
                http_service
                    .enable_model_endpoint(endpoint_type, true)
                    .await;
            }

133
            http_service
134
135
        }
        EngineConfig::StaticFull { engine, model, .. } => {
136
            let http_service = http_service_builder.build()?;
137
138
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
139
140
            manager.add_completions_model(model.service_name(), engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), engine)?;
141
142
143
144
145
146
147

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
                http_service
                    .enable_model_endpoint(endpoint_type, true)
                    .await;
            }
148
            http_service
149
        }
150
151
        EngineConfig::StaticCore {
            engine: inner_engine,
152
            model,
153
            ..
154
        } => {
155
            let http_service = http_service_builder.build()?;
156
157
158
159
160
            let manager = http_service.model_manager();

            let chat_pipeline = common::build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
161
            >(model.card(), inner_engine.clone())
162
            .await?;
163
            manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
164

165
166
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
167
                NvCreateCompletionResponse,
168
            >(model.card(), inner_engine)
169
            .await?;
170
            manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
171
172
173
174
175
176
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
                http_service
                    .enable_model_endpoint(endpoint_type, true)
                    .await;
            }
177
            http_service
178
        }
179
    };
180
181
182
183
184
185
186
187
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
188
189
190
    http_service.run(runtime.primary_token()).await?;
    runtime.shutdown(); // Cancel primary token
    Ok(())
191
}
192
193
194

/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
195
#[allow(clippy::too_many_arguments)]
196
async fn run_watcher(
197
    runtime: DistributedRuntime,
198
    model_manager: Arc<ModelManager>,
199
200
    etcd_client: etcd::Client,
    network_prefix: &str,
201
    router_mode: RouterMode,
202
    kv_router_config: Option<KvRouterConfig>,
203
    busy_threshold: Option<f64>,
204
    http_service: Arc<HttpService>,
205
) -> anyhow::Result<()> {
206
207
208
209
210
211
212
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
213
214
215
    tracing::info!("Watching for remote model at {network_prefix}");
    let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
    let (_prefix, _watcher, receiver) = models_watcher.dissolve();
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);

    watch_obj.set_notify_on_model_update(tx);

    // Spawn a task to watch for model type changes and update HTTP service endpoints
    let _endpoint_enabler_task = tokio::spawn(async move {
        while let Some(model_type) = rx.recv().await {
            tracing::debug!("Received model type update: {:?}", model_type);
            update_http_endpoints(http_service.clone(), model_type).await;
        }
    });

    // Pass the sender to the watcher
231
232
233
    let _watcher_task = tokio::spawn(async move {
        watch_obj.watch(receiver).await;
    });
234

235
236
    Ok(())
}
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276

/// Enables or disables HTTP service endpoints in response to a model update.
///
/// `ModelUpdate::Added` turns endpoints on and `ModelUpdate::Removed` turns them off.
/// A `ModelType::Backend` model toggles both the chat and the completion endpoint;
/// any other model type toggles the single endpoint returned by its
/// `as_endpoint_type()`.
async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );

    // Reduce the update to "which model type" plus "on or off" so the endpoint
    // selection below only has to be written once.
    let (kind, enabled) = match model_type {
        ModelUpdate::Added(kind) => (kind, true),
        ModelUpdate::Removed(kind) => (kind, false),
    };

    match kind {
        ModelType::Backend => {
            service
                .enable_model_endpoint(EndpointType::Chat, enabled)
                .await;
            service
                .enable_model_endpoint(EndpointType::Completion, enabled)
                .await;
        }
        _ => {
            service
                .enable_model_endpoint(kind.as_endpoint_type(), enabled)
                .await;
        }
    }
}