http.rs 9.65 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use crate::{
7
    discovery::{ModelManager, ModelUpdate, ModelWatcher},
8
    endpoint_type::EndpointType,
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
11
    http::service::service_v2::{self, HttpService},
12
    namespace::NamespaceFilter,
13
14
15
    types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
16
17
    },
};
18
use dynamo_runtime::DistributedRuntime;
19
use dynamo_runtime::metrics::MetricsHierarchy;
20
21

/// Build and run an HTTP service
22
23
24
25
pub async fn run(
    distributed_runtime: DistributedRuntime,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
Graham King's avatar
Graham King committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    let local_model = engine_config.local_model();
    let mut http_service_builder = match (local_model.tls_cert_path(), local_model.tls_key_path()) {
        (Some(tls_cert_path), Some(tls_key_path)) => {
            if !tls_cert_path.exists() {
                anyhow::bail!("TLS certificate not found: {}", tls_cert_path.display());
            }
            if !tls_key_path.exists() {
                anyhow::bail!("TLS key not found: {}", tls_key_path.display());
            }
            service_v2::HttpService::builder()
                .enable_tls(true)
                .tls_cert_path(Some(tls_cert_path.to_path_buf()))
                .tls_key_path(Some(tls_key_path.to_path_buf()))
                .port(local_model.http_port())
        }
        (None, None) => service_v2::HttpService::builder().port(local_model.http_port()),
        (_, _) => {
            // CLI should prevent us ever getting here
            anyhow::bail!(
                "Both --tls-cert-path and --tls-key-path must be provided together to enable TLS"
            );
        }
    };
49
50
51
    if let Some(http_host) = local_model.http_host() {
        http_service_builder = http_service_builder.host(http_host);
    }
52
53
    http_service_builder =
        http_service_builder.cancel_token(Some(distributed_runtime.primary_token()));
Graham King's avatar
Graham King committed
54
55
    http_service_builder =
        http_service_builder.with_request_template(engine_config.local_model().request_template());
56
57
58
59
60
61
62
63
64
65
    // Inject the DRT's metrics registry so that component-scoped metrics
    // (e.g. KvIndexerMetrics) are exposed (default port 8000 if not overridden).
    http_service_builder =
        http_service_builder.drt_metrics(Some(distributed_runtime.get_metrics_registry().clone()));

    // Wire DRT discovery so that router metrics (dynamo_router_*) are registered
    // with the instance_id as the router_id label.
    http_service_builder =
        http_service_builder.drt_discovery(Some(distributed_runtime.discovery()));

66
    let http_service = match engine_config {
67
68
        EngineConfig::Dynamic {
            ref model,
69
            ref chat_engine_factory,
70
        } => {
71
72
73
            // Pass the discovery client so the /health endpoint can query active instances
            http_service_builder =
                http_service_builder.discovery(Some(distributed_runtime.discovery()));
74
            let http_service = http_service_builder.build()?;
75

76
            let router_config = model.router_config();
77
            let migration_limit = model.migration_limit();
78
            // Listen for models registering themselves, add them to HTTP service
79
80
81
82
83
            // Create namespace filter from model configuration
            let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
                model.namespace(),
                model.namespace_prefix(),
            );
84
            run_watcher(
85
                distributed_runtime.clone(),
86
                http_service.state().manager_clone(),
87
                router_config.clone(),
88
                migration_limit,
89
                namespace_filter,
90
91
                Arc::new(http_service.clone()),
                http_service.state().metrics_clone(),
92
                chat_engine_factory.clone(),
93
94
            )
            .await?;
95
            http_service
96
        }
97
        EngineConfig::InProcessText { engine, model, .. } => {
98
            let http_service = http_service_builder.build()?;
99
100
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
            let manager = http_service.model_manager();
101
            let checksum = model.card().mdcsum();
102
103
            manager.add_completions_model(model.display_name(), checksum, engine.clone())?;
            manager.add_chat_completions_model(model.display_name(), checksum, engine)?;
104
105
106

            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
107
                http_service.enable_model_endpoint(endpoint_type, true);
108
            }
109
            http_service
110
        }
111
        EngineConfig::InProcessTokens {
112
            engine: inner_engine,
113
            model,
114
            ..
115
        } => {
116
            let http_service = http_service_builder.build()?;
117
            let manager = http_service.model_manager();
118
            let checksum = model.card().mdcsum();
119

Nikita's avatar
Nikita committed
120
121
122
123
124
125
            let tokenizer = model.card().tokenizer()?;
            let chat_pipeline = common::build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
            >(model.card(), inner_engine.clone(), tokenizer.clone())
            .await?;
126
            manager.add_chat_completions_model(model.display_name(), checksum, chat_pipeline)?;
127

128
129
            let cmpl_pipeline = common::build_pipeline::<
                NvCreateCompletionRequest,
130
                NvCreateCompletionResponse,
Nikita's avatar
Nikita committed
131
            >(model.card(), inner_engine, tokenizer)
132
            .await?;
133
            manager.add_completions_model(model.display_name(), checksum, cmpl_pipeline)?;
134
135
            // Enable all endpoints
            for endpoint_type in EndpointType::all() {
136
                http_service.enable_model_endpoint(endpoint_type, true);
137
            }
138
            http_service
139
        }
140
    };
141
142
143
144
145
146
147
148
    tracing::debug!(
        "Supported routes: {:?}",
        http_service
            .route_docs()
            .iter()
            .map(|rd| rd.to_string())
            .collect::<Vec<String>>()
    );
149

150
151
152
    http_service
        .run(distributed_runtime.primary_token())
        .await?;
153

154
    distributed_runtime.shutdown(); // Cancel primary token
155
    Ok(())
156
}
157

158
/// Spawns a task that watches for new models in store,
159
/// and registers them with the ModelManager so that the HTTP service can use them.
160
#[allow(clippy::too_many_arguments)]
161
async fn run_watcher(
162
    runtime: DistributedRuntime,
163
    model_manager: Arc<ModelManager>,
164
    router_config: RouterConfig,
165
    migration_limit: u32,
166
    namespace_filter: NamespaceFilter,
167
    http_service: Arc<HttpService>,
168
    metrics: Arc<crate::http::service::metrics::Metrics>,
169
    chat_engine_factory: Option<ChatEngineFactoryCallback>,
170
) -> anyhow::Result<()> {
171
172
173
174
    let mut watch_obj = ModelWatcher::new(
        runtime.clone(),
        model_manager,
        router_config,
175
        migration_limit,
176
        chat_engine_factory,
177
        metrics.clone(),
178
    );
179
    tracing::debug!("Waiting for remote model");
180
181
182
183
184
185
186
    let discovery = runtime.discovery();
    let discovery_stream = discovery
        .list_and_watch(
            dynamo_runtime::discovery::DiscoveryQuery::AllModels,
            Some(runtime.primary_token()),
        )
        .await?;
187
188
189
190
191

    // Create a channel to receive model type updates
    let (tx, mut rx) = tokio::sync::mpsc::channel(32);
    watch_obj.set_notify_on_model_update(tx);

192
    // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
193
    let _endpoint_enabler_task = tokio::spawn(async move {
194
195
196
        while let Some(model_update) = rx.recv().await {
            update_http_endpoints(http_service.clone(), model_update.clone());
            update_model_metrics(model_update, metrics.clone());
197
198
199
        }
    });

200
    // Pass the discovery stream to the watcher
201
    let _watcher_task = tokio::spawn(async move {
202
        watch_obj.watch(discovery_stream, namespace_filter).await;
203
    });
204

205
206
    Ok(())
}
207
208

/// Updates HTTP service endpoints based on available model types
209
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
210
211
212
213
214
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );
    match model_type {
215
        ModelUpdate::Added(card) => {
216
            // Handle all supported endpoint types, not just the first one
217
            for endpoint_type in card.model_type.as_endpoint_types() {
218
                service.enable_model_endpoint(endpoint_type, true);
219
            }
220
        }
221
        ModelUpdate::Removed(card) => {
222
            // Handle all supported endpoint types, not just the first one
223
            for endpoint_type in card.model_type.as_endpoint_types() {
224
                service.enable_model_endpoint(endpoint_type, false);
225
            }
226
        }
227
228
    }
}
229
230

/// Updates metrics for model type changes
231
fn update_model_metrics(
232
233
234
235
    model_type: ModelUpdate,
    metrics: Arc<crate::http::service::metrics::Metrics>,
) {
    match model_type {
236
237
238
239
        ModelUpdate::Added(card) => {
            tracing::debug!("Updating metrics for added model: {}", card.display_name);
            if let Err(err) = metrics.update_metrics_from_mdc(&card) {
                tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
240
241
            }
        }
242
243
        ModelUpdate::Removed(card) => {
            tracing::debug!(model_name = card.display_name, "Model removed");
244
245
246
247
248
            // Note: Metrics are typically not removed to preserve historical data
            // This matches the behavior in the polling task
        }
    }
}