http.rs 9.42 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
8
    discovery::{ModelManager, ModelUpdate, ModelWatcher, MODEL_ROOT_PATH},
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{self, input::common, EngineConfig},
11
    http::service::service_v2::{self, HttpService},
12
    kv_router::KvRouterConfig,
13
    model_type::ModelType,
14
15
16
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
17
18
    },
};
19
use dynamo_runtime::transports::etcd;
20
use dynamo_runtime::{distributed::DistributedConfig, pipeline::RouterMode};
21
use dynamo_runtime::{DistributedRuntime, Runtime};
22
23

/// Build and run an HTTP service
24
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
25
    let mut http_service_builder = service_v2::HttpService::builder()
26
        .port(engine_config.local_model().http_port())
27
        .with_request_template(engine_config.local_model().request_template());
28

29
    let http_service = match engine_config {
30
        EngineConfig::Dynamic(_) => {
31
32
33
34
35
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
            let etcd_client = distributed_runtime.etcd_client();
            // This allows the /health endpoint to query etcd for active instances
            http_service_builder = http_service_builder.with_etcd_client(etcd_client.clone());
            let http_service = http_service_builder.build()?;
36
37
            match etcd_client {
                Some(ref etcd_client) => {
38
                    let router_config = engine_config.local_model().router_config();
39
                    // Listen for models registering themselves in etcd, add them to HTTP service
40
                    run_watcher(
41
                        distributed_runtime,
42
                        http_service.state().manager_clone(),
43
                        etcd_client.clone(),
44
                        MODEL_ROOT_PATH,
45
                        router_config.router_mode,
46
                        Some(router_config.kv_router_config),
47
                        Arc::new(http_service.clone()),
48
49
                    )
                    .await?;
50
51
52
53
54
                }
                None => {
                    // Static endpoints don't need discovery
                }
            }
55
            http_service
56
        }
57
58
59
60
        EngineConfig::StaticRemote(local_model) => {
            let card = local_model.card();
            let router_mode = local_model.router_config().router_mode;

61
            let dst_config = DistributedConfig::from_settings(true); // true means static
62
            let distributed_runtime = DistributedRuntime::new(runtime.clone(), dst_config).await?;
63
            let http_service = http_service_builder.build()?;
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
            let manager = http_service.model_manager();

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?;
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                Some(
                    manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
            >(card, &client, router_mode, kv_chooser.clone())
            .await?;
            manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;

            let completions_engine = entrypoint::build_routed_pipeline::<
                NvCreateCompletionRequest,
                NvCreateCompletionResponse,
            >(card, &client, router_mode, kv_chooser)
            .await?;
            manager.add_completions_model(local_model.display_name(), completions_engine)?;
100

101
102
103
104
105
106
            for endpoint_type in EndpointType::all() {
                http_service
                    .enable_model_endpoint(endpoint_type, true)
                    .await;
            }

107
            http_service
108
109
        }
        EngineConfig::StaticFull { engine, model, .. } => {
110
            let http_service = http_service_builder.build()?;
111
112
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
113
114
            manager.add_completions_model(model.service_name(), engine.clone())?;
            manager.add_chat_completions_model(model.service_name(), engine)?;
115
116
117
118
119
120
121

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
                http_service
                    .enable_model_endpoint(endpoint_type, true)
                    .await;
            }
122
            http_service
123
        }
124
125
        EngineConfig::StaticCore {
            engine: inner_engine,
126
            model,
127
            ..
128
        } => {
129
            let http_service = http_service_builder.build()?;
130
131
132
133
134
            let manager = http_service.model_manager();

            let chat_pipeline = common::build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
135
            >(model.card(), inner_engine.clone())
136
            .await?;
137
            manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
138

139
140
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
141
                NvCreateCompletionResponse,
142
            >(model.card(), inner_engine)
143
            .await?;
144
            manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
145
146
147
148
149
150
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
                http_service
                    .enable_model_endpoint(endpoint_type, true)
                    .await;
            }
151
            http_service
152
        }
153
    };
154
155
156
157
158
159
160
161
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
162
163
164
    http_service.run(runtime.primary_token()).await?;
    runtime.shutdown(); // Cancel primary token
    Ok(())
165
}
166
167
168
169

/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
async fn run_watcher(
170
    runtime: DistributedRuntime,
171
    model_manager: Arc<ModelManager>,
172
173
    etcd_client: etcd::Client,
    network_prefix: &str,
174
    router_mode: RouterMode,
175
    kv_router_config: Option<KvRouterConfig>,
176
    http_service: Arc<HttpService>,
177
) -> anyhow::Result<()> {
178
    let mut watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config);
179
180
181
    tracing::info!("Watching for remote model at {network_prefix}");
    let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
    let (_prefix, _watcher, receiver) = models_watcher.dissolve();
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);

    watch_obj.set_notify_on_model_update(tx);

    // Spawn a task to watch for model type changes and update HTTP service endpoints
    let _endpoint_enabler_task = tokio::spawn(async move {
        while let Some(model_type) = rx.recv().await {
            tracing::debug!("Received model type update: {:?}", model_type);
            update_http_endpoints(http_service.clone(), model_type).await;
        }
    });

    // Pass the sender to the watcher
197
198
199
    let _watcher_task = tokio::spawn(async move {
        watch_obj.watch(receiver).await;
    });
200

201
202
    Ok(())
}
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242

/// Enables or disables HTTP endpoints when a model is added or removed.
///
/// `ModelUpdate::Added` turns endpoints on and `ModelUpdate::Removed` turns them off.
/// A `ModelType::Backend` model toggles both the chat and completion endpoints (chat
/// first). Any other model type toggles the single endpoint returned by
/// `as_endpoint_type()`.
async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );

    let (kind, enable) = match model_type {
        ModelUpdate::Added(kind) => (kind, true),
        ModelUpdate::Removed(kind) => (kind, false),
    };

    if matches!(kind, ModelType::Backend) {
        service.enable_model_endpoint(EndpointType::Chat, enable).await;
        service.enable_model_endpoint(EndpointType::Completion, enable).await;
    } else {
        service.enable_model_endpoint(kind.as_endpoint_type(), enable).await;
    }
}