discovery.rs 10.2 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Ryan Olson's avatar
Ryan Olson committed
16
use std::sync::Arc;
17

18
use anyhow::Context as _;
Ryan Olson's avatar
Ryan Olson committed
19
use serde::{Deserialize, Serialize};
20
use tokio::sync::mpsc::Receiver;
Ryan Olson's avatar
Ryan Olson committed
21

Neelay Shah's avatar
Neelay Shah committed
22
use dynamo_runtime::{
23
24
    component::{self, ComponentEndpointInfo},
    pipeline::network::egress::push_router::PushRouter,
Ryan Olson's avatar
Ryan Olson committed
25
    protocols::{self, annotated::Annotated},
26
    raise,
27
28
    slug::Slug,
    transports::etcd::{self, KeyValue, WatchEvent},
29
    DistributedRuntime,
30
};
Ryan Olson's avatar
Ryan Olson committed
31
32

use super::ModelManager;
33
use crate::model_type::ModelType;
Ryan Olson's avatar
Ryan Olson committed
34
use crate::protocols::openai::chat_completions::{
35
    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
36
};
37
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
38
39
40
41
use crate::{
    key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
    model_card::{self, ModelDeploymentCard},
};
42
use tracing;
43

Ryan Olson's avatar
Ryan Olson committed
44
45
46
47
48
49
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
    /// Public name of the model
    /// This will be used to identify the model in the HTTP service and the value used in an
50
    /// an [OAI ChatRequest][crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest].
Ryan Olson's avatar
Ryan Olson committed
51
52
53
54
    pub name: String,

    /// Component of the endpoint.
    pub endpoint: protocols::Endpoint,
55
56
57

    /// Specifies whether the model is a chat or completion model.s
    pub model_type: ModelType,
Ryan Olson's avatar
Ryan Olson committed
58
59
}

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
impl ModelEntry {
    /// Loads the [`ModelDeploymentCard`] for this model from etcd.
    ///
    /// The card is looked up in [`model_card::BUCKET_NAME`] under the key returned by
    /// [`ModelDeploymentCard::service_name_slug`] for `self.name`, through an [`EtcdStorage`]
    /// built from `etcd_client` and `endpoint_id`.
    ///
    /// # Errors
    ///
    /// Returns an error when no card is stored under that key, or when the store lookup itself
    /// fails.
    pub async fn load_mdc(
        &self,
        endpoint_id: protocols::Endpoint,
        etcd_client: etcd::Client,
    ) -> anyhow::Result<ModelDeploymentCard> {
        // `etcd_client` is owned and not used again, so it is moved rather than cloned.
        let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client, endpoint_id));
        let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
        let card_key = ModelDeploymentCard::service_name_slug(&self.name);

        let lookup = card_store
            .load::<ModelDeploymentCard>(model_card::BUCKET_NAME, &card_key)
            .await;
        match lookup {
            Ok(Some(mdc)) => Ok(mdc),
            Ok(None) => {
                anyhow::bail!("Missing ModelDeploymentCard in etcd under key {card_key}");
            }
            Err(err) => {
                anyhow::bail!(
                    "Error fetching ModelDeploymentCard from etcd under key {card_key}. {err}"
                );
            }
        }
    }
}

/// Key under which a [`ModelEntry`] is stored in the networked key-value store (etcd).
///
/// Wraps the slugified `namespace.component.endpoint-lease_id` string built by `from_parts`.
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);

impl ModelNetworkName {
    /// Key to store this model entry in networked key-value store (etcd).
    ///
    /// It looks like this:
    /// ns.cp.ep-694d967ca5efd804
    fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
        ModelNetworkName(
            Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}")).to_string(),
        )
    }

    // We can't do From<&component::Endpoint> here because we also need the lease_id
    pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
        Self::from_parts(
            &endpoint.component().namespace().to_string(),
            &endpoint.component().name(),
            endpoint.name(),
            lease_id,
        )
    }

    pub async fn load_mdc(
        &self,
        endpoint_id: protocols::Endpoint,
        etcd_client: etcd::Client,
    ) -> anyhow::Result<ModelDeploymentCard> {
        let network_name = self;
        let model_entries = etcd_client.kv_get(network_name.to_string(), None).await?;
        if model_entries.is_empty() {
            anyhow::bail!("No ModelEntry in etcd for key {network_name}");
        }
        let entry: ModelEntry =
            serde_json::from_slice(model_entries[0].value()).with_context(|| {
                format!(
                    "Error deserializing JSON. Key={network_name}. JSON={}",
                    model_entries[0].value_str().unwrap_or("INVALID UTF-8")
                )
            })?;
        entry.load_mdc(endpoint_id, etcd_client).await
    }
}

impl From<&ComponentEndpointInfo> for ModelNetworkName {
    fn from(cei: &ComponentEndpointInfo) -> Self {
        Self::from_parts(&cei.namespace, &cei.component, &cei.endpoint, cei.lease_id)
    }
}

impl std::fmt::Display for ModelNetworkName {
    /// Writes the wrapped key string as-is; width and fill flags on the formatter are not
    /// applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

Ryan Olson's avatar
Ryan Olson committed
144
145
pub struct ModelWatchState {
    pub prefix: String,
146
    pub model_type: ModelType,
Ryan Olson's avatar
Ryan Olson committed
147
148
149
150
    pub manager: ModelManager,
    pub drt: DistributedRuntime,
}

151
pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<WatchEvent>) {
152
    tracing::debug!("model watcher started");
Ryan Olson's avatar
Ryan Olson committed
153
154
155

    while let Some(event) = events_rx.recv().await {
        match event {
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
            WatchEvent::Put(kv) => {
                let key = match kv.key_str() {
                    Ok(key) => key,
                    Err(err) => {
                        tracing::error!(%err, ?kv, "Invalid UTF8 in model key");
                        continue;
                    }
                };
                tracing::debug!(key, "adding model");

                // model_entry.name is the service name (e.g. "Llama-3.2-3B-Instruct")
                let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
                    Ok(model_entry) => model_entry,
                    Err(err) => {
                        tracing::error!(%err, ?kv, "Invalid JSON in model entry");
                        continue;
                    }
                };
                if state.manager.has_model_any(&model_entry.name) {
                    tracing::trace!(
                        service_name = model_entry.name,
                        "New endpoint for existing model"
                    );
                    continue;
Ryan Olson's avatar
Ryan Olson committed
180
                }
181
182
183
184
185
186
187
188

                match handle_put(model_entry, state.clone()).await {
                    Ok((model_name, model_type)) => {
                        tracing::info!("added {} model: {}", model_type, model_name);
                    }
                    Err(e) => {
                        tracing::error!("error adding model: {}", e);
                    }
Ryan Olson's avatar
Ryan Olson committed
189
                }
190
            }
Ryan Olson's avatar
Ryan Olson committed
191
            WatchEvent::Delete(kv) => match handle_delete(&kv, state.clone()).await {
192
193
                Ok((model_name, model_type)) => {
                    tracing::info!("removed {} model: {}", model_type, model_name);
Ryan Olson's avatar
Ryan Olson committed
194
195
                }
                Err(e) => {
196
                    tracing::error!("error removing model: {}", e);
Ryan Olson's avatar
Ryan Olson committed
197
198
199
200
201
202
                }
            },
        }
    }
}

203
204
205
206
async fn handle_delete(
    kv: &KeyValue,
    state: Arc<ModelWatchState>,
) -> anyhow::Result<(&str, ModelType)> {
Ryan Olson's avatar
Ryan Olson committed
207
    let key = kv.key_str()?;
208
    tracing::debug!(key, "removing model");
Ryan Olson's avatar
Ryan Olson committed
209
210

    let model_name = key.trim_start_matches(&state.prefix);
211
212
213
214
215
216
217

    match state.model_type {
        ModelType::Chat => state.manager.remove_chat_completions_model(model_name)?,
        ModelType::Completion => state.manager.remove_completions_model(model_name)?,
    };

    Ok((model_name, state.model_type))
Ryan Olson's avatar
Ryan Olson committed
218
219
220
221
222
223
}

// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
//
// If this method errors, for the near term, we will delete the offending key.
224
225
226
227
async fn handle_put(
    model_entry: ModelEntry,
    state: Arc<ModelWatchState>,
) -> anyhow::Result<(String, ModelType)> {
228
229
230
231
232
233
234
    if model_entry.model_type != state.model_type {
        raise!(
            "model type mismatch: {} != {}",
            model_entry.model_type,
            state.model_type
        );
    }
Ryan Olson's avatar
Ryan Olson committed
235

236
237
    match state.model_type {
        ModelType::Chat => {
238
            let endpoint_id = model_entry.endpoint.clone();
239
240
            let client = state
                .drt
241
242
243
244
                .namespace(&endpoint_id.namespace)?
                .component(&endpoint_id.component)?
                .endpoint(&endpoint_id.name)
                .client()
245
                .await?;
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264

            let Some(etcd_client) = state.drt.etcd_client() else {
                // Should be impossible because we only get here on an etcd event
                anyhow::bail!("Missing etcd_client");
            };
            let mdc = model_entry.load_mdc(endpoint_id, etcd_client).await?;
            if mdc.requires_preprocessing {
                // Note requires_preprocessing is never true in our code right now
                todo!("Ingress-side pre-processing not supported yet");
            } else {
                let push_router = PushRouter::<
                    NvCreateChatCompletionRequest,
                    Annotated<NvCreateChatCompletionStreamResponse>,
                >::from_client(client, Default::default())
                .await?;
                state
                    .manager
                    .add_chat_completions_model(&model_entry.name, Arc::new(push_router))?;
            }
265
266
267
268
269
270
271
        }
        ModelType::Completion => {
            let client = state
                .drt
                .namespace(model_entry.endpoint.namespace)?
                .component(model_entry.endpoint.component)?
                .endpoint(model_entry.endpoint.name)
272
273
274
275
276
277
278
279
280
281
                .client()
                .await?;

            // TODO: Handle pre-processing once it moves ingress-side

            let push_router =
                PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
                    client,
                    Default::default(),
                )
282
283
284
                .await?;
            state
                .manager
285
                .add_completions_model(&model_entry.name, Arc::new(push_router))?;
286
287
        }
    }
Ryan Olson's avatar
Ryan Olson committed
288

289
    Ok((model_entry.name, state.model_type))
Ryan Olson's avatar
Ryan Olson committed
290
}