http.rs 10.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, EngineConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    model_type::ModelType,
14
15
16
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
17
18
    },
};
19
use dynamo_runtime::transports::etcd;
20
use dynamo_runtime::{DistributedRuntime, Runtime};
21
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
22
23

/// Build and run an HTTP service
24
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
48
49
50
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
Graham King's avatar
Graham King committed
51
52
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
53

54
    let http_service = match engine_config {
55
        EngineConfig::Dynamic(_) => {
56
57
58
59
60
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            let etcd_client = distributed_runtime.etcd_client();
            // This allows the /health endpoint to query etcd for active instances
            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
            let http_service = http_service_builder.build()?;
61
62
            match etcd_client {
                Some(ref etcd_client) => {
63
                    let router_config = engine_config.local_model().router_config();
64
                    // Listen for models registering themselves in etcd, add them to HTTP service
65
                    run_watcher(
66
                        distributed_runtime,
67
                        http_service.state().manager_clone(),
68
                        etcd_client.clone(),
69
                        MODEL_ROOT_PATH,
70
                        router_config.router_mode,
71
                        Some(router_config.kv_router_config),
72
                        router_config.busy_threshold,
73
                        Arc::new(http_service.clone()),
74
75
                    )
                    .await?;
76
77
78
79
80
                }
                None => {
                    // Static endpoints don't need discovery
                }
            }
81
            http_service
82
        }
83
84
85
86
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
            let router_mode = local_model.router_config().router_mode;

87
            let dst_config = DistributedConfig::from_settings(true); // true means static
88
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
89
            let http_service = http_service_builder.build()?;
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
116
            >(card, &client, router_mode, None, kv_chooser.clone())
117
118
119
120
121
122
            .await?;
            manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;

            let completions_engine = entrypoint::build_routed_pipeline::<
                NvCreateCompletionRequest,
                NvCreateCompletionResponse,
123
            >(card, &client, router_mode, None, kv_chooser)
124
125
            .await?;
            manager.add_completions_model(local_model.display_name(), completions_engine)?;
126

127
            for endpoint_type in EndpointType::all() {
128
                http_service.enable_model_endpoint(endpoint_type, true);
129
130
            }

131
            http_service
132
133
        }
        EngineConfig::StaticFull { engine, model, .. } => {
134
            let http_service = http_service_builder.build()?;
135
136
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
137
138
            manager.add_completions_model(model.service_name(), engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), engine)?;
139
140
141

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
142
                http_service.enable_model_endpoint(endpoint_type, true);
143
            }
144
            http_service
145
        }
146
147
        EngineConfig::StaticCore {
            engine: inner_engine,
148
            model,
149
            ..
150
        } => {
151
            let http_service = http_service_builder.build()?;
152
153
154
155
156
            let manager = http_service.model_manager();

            let chat_pipeline = common::build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
157
            >(model.card(), inner_engine.clone())
158
            .await?;
159
            manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
160

161
162
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
163
                NvCreateCompletionResponse,
164
            >(model.card(), inner_engine)
165
            .await?;
166
            manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
167
168
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
169
                http_service.enable_model_endpoint(endpoint_type, true);
170
            }
171
            http_service
172
        }
173
    };
174
175
176
177
178
179
180
181
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
182
183
184
    http_service.run(runtime.primary_token()).await?;
    runtime.shutdown(); // Cancel primary token
    Ok(())
185
}
186
187
188

/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
189
#[allow(clippy::too_many_arguments)]
190
async fn run_watcher(
191
    runtime: DistributedRuntime,
192
    model_manager: Arc<ModelManager>,
193
194
    etcd_client: etcd::Client,
    network_prefix: &str,
195
    router_mode: RouterMode,
196
    kv_router_config: Option<KvRouterConfig>,
197
    busy_threshold: Option<f64>,
198
    http_service: Arc<HttpService>,
199
) -> anyhow::Result<()> {
200
201
202
203
204
205
206
    let mut watch_obj = ModelWatcher::new(
        runtime,
        model_manager,
        router_mode,
        kv_router_config,
        busy_threshold,
    );
207
208
209
    tracing::info!("Watching for remote model at {network_prefix}");
    let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
    let (_prefix, _watcher, receiver) = models_watcher.dissolve();
210
211
212
213
214
215
216
217
218
219

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);

    watch_obj.set_notify_on_model_update(tx);

    // Spawn a task to watch for model type changes and update HTTP service endpoints
    let _endpoint_enabler_task = tokio::spawn(async move {
        while let Some(model_type) = rx.recv().await {
            tracing::debug!("Received model type update: {:?}", model_type);
220
            update_http_endpoints(http_service.clone(), model_type);
221
222
223
224
        }
    });

    // Pass the sender to the watcher
225
226
227
    let _watcher_task = tokio::spawn(async move {
        watch_obj.watch(receiver).await;
    });
228

229
230
    Ok(())
}
231
232

/// Updates HTTP service endpoints based on available model types
233
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
234
235
236
237
238
239
240
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
        ModelUpdate::Added(model_type) => match model_type {
            ModelType::Backend => {
241
242
                service.enable_model_endpoint(EndpointType::Chat, true);
                service.enable_model_endpoint(EndpointType::Completion, true);
243
244
            }
            _ => {
245
                service.enable_model_endpoint(model_type.as_endpoint_type(), true);
246
247
248
249
            }
        },
        ModelUpdate::Removed(model_type) => match model_type {
            ModelType::Backend => {
250
251
                service.enable_model_endpoint(EndpointType::Chat, false);
                service.enable_model_endpoint(EndpointType::Completion, false);
252
253
            }
            _ => {
254
                service.enable_model_endpoint(model_type.as_endpoint_type(), false);
255
256
257
258
            }
        },
    }
}